diffsynth_engine-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,54 @@
+ import torch
+ import torch.nn as nn
+ from contextlib import contextmanager
+
+
+ # modified from transformers.modeling_utils
+ TORCH_INIT_FUNCTIONS = {
+     "uniform_": nn.init.uniform_,
+     "normal_": nn.init.normal_,
+     "trunc_normal_": nn.init.trunc_normal_,
+     "constant_": nn.init.constant_,
+     "xavier_uniform_": nn.init.xavier_uniform_,
+     "xavier_normal_": nn.init.xavier_normal_,
+     "kaiming_uniform_": nn.init.kaiming_uniform_,
+     "kaiming_normal_": nn.init.kaiming_normal_,
+     "uniform": nn.init.uniform,
+     "normal": nn.init.normal,
+     "xavier_uniform": nn.init.xavier_uniform,
+     "xavier_normal": nn.init.xavier_normal,
+     "kaiming_uniform": nn.init.kaiming_uniform,
+     "kaiming_normal": nn.init.kaiming_normal,
+ }
+
+ _init_weights = True
+
+
+ @contextmanager
+ def no_init_weights():
+     """
+     Context manager to globally disable weight initialization to speed up loading large models.
+     """
+     global _init_weights
+     old_init_weights = _init_weights
+
+     def _skip_init(*args, **kwargs):
+         pass
+
+     _init_weights = False
+     # Replace every initialization function with a no-op
+     for name, init_func in TORCH_INIT_FUNCTIONS.items():
+         setattr(torch.nn.init, name, _skip_init)
+     try:
+         yield
+     finally:
+         _init_weights = old_init_weights
+         # Restore the original initialization functions
+         for name, init_func in TORCH_INIT_FUNCTIONS.items():
+             setattr(torch.nn.init, name, init_func)
+
+
+ def zero_module(module: nn.Module):
+     for p in module.parameters():
+         nn.init.zeros_(p)
+     return module
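Note: `no_init_weights` works by monkey-patching every function listed in `TORCH_INIT_FUNCTIONS` on `torch.nn.init`, so any module constructed inside the context skips random initialization that a subsequent checkpoint load would overwrite anyway. A minimal usage sketch; the model class and checkpoint path below are hypothetical, not part of the package:

import torch
import torch.nn as nn

from diffsynth_engine.models.utils import no_init_weights


class TinyNet(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)


with no_init_weights():
    # nn.init.* are no-ops here, so construction leaves weights
    # uninitialized instead of spending time on random init.
    model = TinyNet()

# All weights then come from the checkpoint ("tiny.pt" is a placeholder path).
model.load_state_dict(torch.load("tiny.pt"))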
File without changes
@@ -0,0 +1,200 @@
+ import torch
+ import warnings
+
+ try:
+     import flash_attn_interface
+
+     FLASH_ATTN_3_AVAILABLE = True
+ except ModuleNotFoundError:
+     FLASH_ATTN_3_AVAILABLE = False
+
+ try:
+     import flash_attn
+
+     FLASH_ATTN_2_AVAILABLE = True
+ except ModuleNotFoundError:
+     FLASH_ATTN_2_AVAILABLE = False
+
+
+ def flash_attention(
+     q,
+     k,
+     v,
+     q_lens=None,
+     k_lens=None,
+     dropout_p=0.0,
+     softmax_scale=None,
+     q_scale=None,
+     causal=False,
+     window_size=(-1, -1),
+     deterministic=False,
+     dtype=torch.bfloat16,
+     version=None,
+ ):
+     """
+     q: [B, Lq, Nq, C1].
+     k: [B, Lk, Nk, C1].
+     v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+     q_lens: [B].
+     k_lens: [B].
+     dropout_p: float. Dropout probability.
+     softmax_scale: float. The scaling of QK^T before applying softmax.
+     causal: bool. Whether to apply a causal attention mask.
+     window_size: (left, right). If not (-1, -1), apply sliding window local attention.
+     deterministic: bool. If True, slightly slower and uses more memory.
+     dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
+     """
+     half_dtypes = (torch.float16, torch.bfloat16)
+     assert dtype in half_dtypes
+     assert q.device.type == "cuda" and q.size(-1) <= 256
+
+     # params
+     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+     def half(x):
+         return x if x.dtype in half_dtypes else x.to(dtype)
+
+     # preprocess query
+     if q_lens is None:
+         q = half(q.flatten(0, 1))
+         q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
+     else:
+         q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+     # preprocess key, value
+     if k_lens is None:
+         k = half(k.flatten(0, 1))
+         v = half(v.flatten(0, 1))
+         k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
+     else:
+         k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+         v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+     q = q.to(v.dtype)
+     k = k.to(v.dtype)
+
+     if q_scale is not None:
+         q = q * q_scale
+
+     if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+         warnings.warn("Flash attention 3 is not available, falling back to flash attention 2.")
+
+     # apply attention
+     if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+         # Note: dropout_p, window_size are not supported in FA3 now.
+         x = flash_attn_interface.flash_attn_varlen_func(
+             q=q,
+             k=k,
+             v=v,
+             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             seqused_q=None,
+             seqused_k=None,
+             max_seqlen_q=lq,
+             max_seqlen_k=lk,
+             softmax_scale=softmax_scale,
+             causal=causal,
+             deterministic=deterministic,
+         )[0].unflatten(0, (b, lq))
+     elif FLASH_ATTN_2_AVAILABLE:
+         x = flash_attn.flash_attn_varlen_func(
+             q=q,
+             k=k,
+             v=v,
+             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
+             .cumsum(0, dtype=torch.int32)
+             .to(q.device, non_blocking=True),
+             max_seqlen_q=lq,
+             max_seqlen_k=lk,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             causal=causal,
+             window_size=window_size,
+             deterministic=deterministic,
+         ).unflatten(0, (b, lq))
+     else:
+         q = q.unsqueeze(0).transpose(1, 2).to(dtype)
+         k = k.unsqueeze(0).transpose(1, 2).to(dtype)
+         v = v.unsqueeze(0).transpose(1, 2).to(dtype)
+         x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+         x = x.transpose(1, 2).contiguous()
+
+     # output
+     return x.type(out_dtype)
+
+
+ def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
+     b, lq, lk = q.size(0), q.size(1), k.size(1)
+     if q_lens is None:
+         q_lens = torch.tensor([lq] * b, dtype=torch.int32)
+     if k_lens is None:
+         k_lens = torch.tensor([lk] * b, dtype=torch.int32)
+     attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
+     for i in range(b):
+         q_len, k_len = q_lens[i], k_lens[i]
+         attn_mask[i, q_len:, :] = True
+         attn_mask[i, :, k_len:] = True
+
+         if causal:
+             causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
+             attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)
+
+     attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
+     return attn_mask
+
+
+ def attention(
+     q,
+     k,
+     v,
+     q_lens=None,
+     k_lens=None,
+     dropout_p=0.0,
+     softmax_scale=None,
+     q_scale=None,
+     causal=False,
+     window_size=(-1, -1),
+     deterministic=False,
+     dtype=torch.bfloat16,
+     fa_version=None,
+ ):
+     if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+         return flash_attention(
+             q=q,
+             k=k,
+             v=v,
+             q_lens=q_lens,
+             k_lens=k_lens,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             q_scale=q_scale,
+             causal=causal,
+             window_size=window_size,
+             deterministic=deterministic,
+             dtype=dtype,
+             version=fa_version,
+         )
+     else:
+         if q_lens is not None or k_lens is not None:
+             warnings.warn(
+                 "Padding mask is disabled when using scaled_dot_product_attention; this can have a significant impact on the results."
+             )
+         attn_mask = None
+
+         q = q.transpose(1, 2).to(dtype)
+         k = k.transpose(1, 2).to(dtype)
+         v = v.transpose(1, 2).to(dtype)
+
+         out = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
+         )
+
+         out = out.transpose(1, 2).contiguous()
+         return out
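The `attention` entry point above dispatches to the FlashAttention varlen kernels when `flash_attn` or `flash_attn_interface` is importable, and otherwise falls back to `torch.nn.functional.scaled_dot_product_attention` without a padding mask. A minimal call sketch under assumed shapes (all sizes are arbitrary; `k_lens` marks per-sample key/value padding, which only the FlashAttention path honors):

import torch

from diffsynth_engine.models.wan.attention import attention

b, lq, lk, num_heads, head_dim = 2, 16, 24, 8, 64
q = torch.randn(b, lq, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, lk, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, lk, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Valid key/value length per sample; trailing positions are padding.
k_lens = torch.tensor([24, 18], dtype=torch.int32)

out = attention(q, k, v, k_lens=k_lens)  # [B, Lq, N, C] on the FlashAttention path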
@@ -0,0 +1,431 @@
+ import math
+ import json
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Dict, Tuple, Optional
+ from einops import rearrange
+
+ from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.utils.constants import (
+     WAN_DIT_1_3B_T2V_CONFIG_FILE,
+     WAN_DIT_14B_I2V_CONFIG_FILE,
+     WAN_DIT_14B_T2V_CONFIG_FILE,
+ )
+
+
+ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int):
+     q, k, v = (rearrange(t, "b s (n d) -> b n s d", n=num_heads) for t in (q, k, v))
+     x = F.scaled_dot_product_attention(q, k, v)
+     x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+     return x
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
+     return x * (1 + scale) + shift
+
+
+ def sinusoidal_embedding_1d(dim, position):
+     sinusoid = torch.outer(
+         position.type(torch.float64),
+         torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2)),
+     )
+     x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+     return x.to(position.dtype)
+
+
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
+     # 3D RoPE precompute: split head channels across frame/height/width axes
+     f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
+     h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+     w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+     return f_freqs_cis, h_freqs_cis, w_freqs_cis
+
+
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
+     # 1D RoPE precompute
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
+     freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex128
+     return freqs_cis
+
+
+ def rope_apply(x, freqs, num_heads):
+     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
+     x_out = torch.view_as_complex(x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
+     x_out = torch.view_as_real(x_out * freqs).flatten(2)
+     return x_out.to(x.dtype)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(
+         self,
+         dim,
+         eps=1e-5,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.eps = eps
+         self.dim = dim
+         self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+
+     def norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         return self.norm(x.float()).to(x.dtype) * self.weight
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         eps: float = 1e-6,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = dim // num_heads
+
+         self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+
+     def forward(self, x, freqs):
+         q = self.norm_q(self.q(x))
+         k = self.norm_k(self.k(x))
+         v = self.v(x)
+         num_heads = q.shape[2] // self.head_dim
+         x = attention(q=rope_apply(q, freqs, num_heads), k=rope_apply(k, freqs, num_heads), v=v, num_heads=num_heads)
+         return self.o(x)
+
+
+ class CrossAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         eps: float = 1e-6,
+         has_image_input: bool = False,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = dim // num_heads
+
+         self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.has_image_input = has_image_input
+         if has_image_input:
+             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
+             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
+             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         if self.has_image_input:
+             img = y[:, :257]
+             ctx = y[:, 257:]
+         else:
+             ctx = y
+         q = self.norm_q(self.q(x))
+         k = self.norm_k(self.k(ctx))
+         v = self.v(ctx)
+         num_heads = q.shape[2] // self.head_dim
+         x = attention(q, k, v, num_heads=num_heads)
+         if self.has_image_input:
+             k_img = self.norm_k_img(self.k_img(img))
+             v_img = self.v_img(img)
+             y = attention(q, k_img, v_img, num_heads=num_heads)
+             x = x + y
+         return self.o(x)
+
+
+ class DiTBlock(nn.Module):
+     def __init__(
+         self,
+         has_image_input: bool,
+         dim: int,
+         num_heads: int,
+         ffn_dim: int,
+         eps: float = 1e-6,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.ffn_dim = ffn_dim
+
+         self.self_attn = SelfAttention(dim, num_heads, eps, device=device, dtype=dtype)
+         self.cross_attn = CrossAttention(
+             dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
+         )
+         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.norm3 = nn.LayerNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.ffn = nn.Sequential(
+             nn.Linear(dim, ffn_dim, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(ffn_dim, dim, device=device, dtype=dtype),
+         )
+         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
+
+     def forward(self, x, context, t_mod, freqs):
+         # msa: multi-head self-attention; mlp: multi-layer perceptron
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + t_mod).chunk(6, dim=1)
+         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+         x = x + gate_msa * self.self_attn(input_x, freqs)
+         x = x + self.cross_attn(self.norm3(x), context)
+         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+         x = x + gate_mlp * self.ffn(input_x)
+         return x
+
+
+ class MLP(torch.nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.proj = torch.nn.Sequential(
+             nn.LayerNorm(in_dim, device=device, dtype=dtype),
+             nn.Linear(in_dim, in_dim, device=device, dtype=dtype),
+             nn.GELU(),
+             nn.Linear(in_dim, out_dim, device=device, dtype=dtype),
+             nn.LayerNorm(out_dim, device=device, dtype=dtype),
+         )
+
+     def forward(self, x):
+         return self.proj(x)
+
+
+ class Head(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         out_dim: int,
+         patch_size: Tuple[int, int, int],
+         eps: float,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.patch_size = patch_size
+         self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.head = nn.Linear(dim, out_dim * math.prod(patch_size), device=device, dtype=dtype)
+         self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
+
+     def forward(self, x, t_mod):
+         shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
+         x = self.head(self.norm(x) * (1 + scale) + shift)
+         return x
+
+
+ class WanDiTStateDictConverter(StateDictConverter):
+     def convert(self, state_dict):
+         return state_dict
+
+
+ class WanDiT(PreTrainedModel):
+     converter = WanDiTStateDictConverter()
+
+     def __init__(
+         self,
+         dim: int,
+         in_dim: int,
+         ffn_dim: int,
+         out_dim: int,
+         text_dim: int,
+         freq_dim: int,
+         eps: float,
+         patch_size: Tuple[int, int, int],
+         num_heads: int,
+         num_layers: int,
+         has_image_input: bool,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+
+         self.dim = dim
+         self.freq_dim = freq_dim
+         self.has_image_input = has_image_input
+         self.patch_size = patch_size
+
+         self.patch_embedding = nn.Conv3d(
+             in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
+         )
+         self.text_embedding = nn.Sequential(
+             nn.Linear(text_dim, dim, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(dim, dim, device=device, dtype=dtype),
+         )
+         self.time_embedding = nn.Sequential(
+             nn.Linear(freq_dim, dim, device=device, dtype=dtype),
+             nn.SiLU(),
+             nn.Linear(dim, dim, device=device, dtype=dtype),
+         )
+         self.time_projection = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(dim, dim * 6, device=device, dtype=dtype),
+         )
+         self.blocks = nn.ModuleList(
+             [
+                 DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, device=device, dtype=dtype)
+                 for _ in range(num_layers)
+             ]
+         )
+         self.head = Head(dim, out_dim, patch_size, eps, device=device, dtype=dtype)
+         head_dim = dim // num_heads
+         self.freqs = precompute_freqs_cis_3d(head_dim)
+
+         if has_image_input:
+             self.img_emb = MLP(1280, dim, device=device, dtype=dtype)  # clip_feature_dim = 1280
+
+     def patchify(self, x: torch.Tensor):
+         x = self.patch_embedding(x)  # (b, c, f, h, w) -> (b, dim, f', h', w')
+         grid_size = x.shape[2:]
+         x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
+         return x, grid_size  # grid_size: (f', h', w')
+
+     def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
+         return rearrange(
+             x,
+             "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
+             f=grid_size[0],
+             h=grid_size[1],
+             w=grid_size[2],
+             x=self.patch_size[0],
+             y=self.patch_size[1],
+             z=self.patch_size[2],
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor,
+         timestep: torch.Tensor,
+         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
+         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+     ):
+         t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
+         t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
+         context = self.text_embedding(context)
+         if self.has_image_input:
+             x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+             clip_embedding = self.img_emb(clip_feature)
+             context = torch.cat([clip_embedding, context], dim=1)  # (b, s1 + s2, d)
+         x, (f, h, w) = self.patchify(x)
+         freqs = (
+             torch.cat(
+                 [
+                     self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                     self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                     self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+                 ],
+                 dim=-1,
+             )
+             .reshape(f * h * w, 1, -1)
+             .to(x.device)
+         )
+         for block in self.blocks:
+             x = block(x, context, t_mod, freqs)
+         x = self.head(x, t)
+         x = self.unpatchify(x, (f, h, w))
+         return x
+
+     @classmethod
+     def from_state_dict(
+         cls,
+         state_dict: Dict[str, torch.Tensor],
+         device: str,
+         dtype: torch.dtype,
+         model_type: str = "1.3b-t2v",
+     ):
+         if model_type == "1.3b-t2v":
+             config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
+         elif model_type == "14b-t2v":
+             config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
+         elif model_type == "14b-i2v":
+             config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
+         else:
+             raise ValueError(f"Unsupported model type: {model_type}")
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype)
+             model.load_state_dict(state_dict, assign=True)
+             model.to(device=device, dtype=dtype, non_blocking=True)
+         return model
+
+     def get_tp_plan(self):
+         from torch.distributed.tensor.parallel import (
+             ColwiseParallel,
+             RowwiseParallel,
+             SequenceParallel,
+             PrepareModuleOutput,
+         )
+         from torch.distributed.tensor import Replicate, Shard
+
+         tp_plan = {
+             "text_embedding.0": ColwiseParallel(),
+             "text_embedding.2": RowwiseParallel(),
+             "time_embedding.0": ColwiseParallel(),
+             "time_embedding.2": RowwiseParallel(),
+             "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
+         }
+         for idx in range(len(self.blocks)):
+             tp_plan.update(
+                 {
+                     f"blocks.{idx}.norm1": SequenceParallel(use_local_output=True),
+                     f"blocks.{idx}.norm2": SequenceParallel(use_local_output=True),
+                     f"blocks.{idx}.norm3": SequenceParallel(use_local_output=True),
+                     f"blocks.{idx}.ffn.0": ColwiseParallel(),
+                     f"blocks.{idx}.ffn.2": RowwiseParallel(),
+                     f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Replicate()),
+                     f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Replicate()),
+                     f"blocks.{idx}.self_attn.v": ColwiseParallel(),
+                     f"blocks.{idx}.self_attn.o": RowwiseParallel(),
+                     f"blocks.{idx}.self_attn.norm_q": PrepareModuleOutput(
+                         output_layouts=Replicate(),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.self_attn.norm_k": PrepareModuleOutput(
+                         output_layouts=Replicate(),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Replicate()),
+                     f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Replicate()),
+                     f"blocks.{idx}.cross_attn.v": ColwiseParallel(),
+                     f"blocks.{idx}.cross_attn.o": RowwiseParallel(),
+                     f"blocks.{idx}.cross_attn.norm_q": PrepareModuleOutput(
+                         output_layouts=Replicate(),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn.norm_k": PrepareModuleOutput(
+                         output_layouts=Replicate(),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Replicate()),
+                     f"blocks.{idx}.cross_attn.v_img": ColwiseParallel(),
+                     f"blocks.{idx}.cross_attn.norm_k_img": PrepareModuleOutput(
+                         output_layouts=Replicate(),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                 }
+             )
+         return tp_plan
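For reference, `precompute_freqs_cis_3d` splits each attention head's channels into frame/height/width parts of size `dim - 2 * (dim // 3)`, `dim // 3`, and `dim // 3`, and `WanDiT.forward` broadcasts one table row per (f, h, w) grid position before `rope_apply` rotates q/k. A small shape sketch under assumed sizes (all numbers arbitrary):

import torch

from diffsynth_engine.models.wan.wan_dit import precompute_freqs_cis_3d, rope_apply

head_dim, num_heads = 128, 2
f, h, w = 4, 6, 6  # latent grid after patchify

f_cis, h_cis, w_cis = precompute_freqs_cis_3d(head_dim)
# complex pairs per position: 22 + 21 + 21 = 64 = head_dim // 2
assert f_cis.shape[1] + h_cis.shape[1] + w_cis.shape[1] == head_dim // 2

# Same broadcast-and-concat as WanDiT.forward: one row per (f, h, w) position.
freqs = torch.cat(
    [
        f_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        h_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
        w_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ],
    dim=-1,
).reshape(f * h * w, 1, -1)

x = torch.randn(1, f * h * w, num_heads * head_dim)
assert rope_apply(x, freqs, num_heads).shape == x.shape  # rotation preserves shape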