diffsynth-engine 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,54 @@
+ import torch
+ import torch.nn as nn
+ from contextlib import contextmanager
+
+
+ # modified from transformers.modeling_utils
+ TORCH_INIT_FUNCTIONS = {
+     "uniform_": nn.init.uniform_,
+     "normal_": nn.init.normal_,
+     "trunc_normal_": nn.init.trunc_normal_,
+     "constant_": nn.init.constant_,
+     "xavier_uniform_": nn.init.xavier_uniform_,
+     "xavier_normal_": nn.init.xavier_normal_,
+     "kaiming_uniform_": nn.init.kaiming_uniform_,
+     "kaiming_normal_": nn.init.kaiming_normal_,
+     "uniform": nn.init.uniform,
+     "normal": nn.init.normal,
+     "xavier_uniform": nn.init.xavier_uniform,
+     "xavier_normal": nn.init.xavier_normal,
+     "kaiming_uniform": nn.init.kaiming_uniform,
+     "kaiming_normal": nn.init.kaiming_normal,
+ }
+
+ _init_weights = True
+
+
+ @contextmanager
+ def no_init_weights():
+     """
+     Context manager to globally disable weight initialization to speed up loading large models.
+     """
+     global _init_weights
+     old_init_weights = _init_weights
+
+     def _skip_init(*args, **kwargs):
+         pass
+
+     _init_weights = False
+     # Replace the initialization functions with a no-op (the originals are kept in TORCH_INIT_FUNCTIONS)
+     for name, init_func in TORCH_INIT_FUNCTIONS.items():
+         setattr(torch.nn.init, name, _skip_init)
+     try:
+         yield
+     finally:
+         _init_weights = old_init_weights
+         # Restore the original initialization functions
+         for name, init_func in TORCH_INIT_FUNCTIONS.items():
+             setattr(torch.nn.init, name, init_func)
+
+
+ def zero_module(module: nn.Module):
+     for p in module.parameters():
+         nn.init.zeros_(p)
+     return module
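The hunk above is the 54-line diffsynth_engine/models/utils.py. Its no_init_weights context manager temporarily swaps every function in TORCH_INIT_FUNCTIONS for a no-op, so modules built inside the block skip random weight initialization; WanDiT.from_state_dict further below pairs it with torch.nn.utils.skip_init and load_state_dict(..., assign=True). A minimal usage sketch follows; the layer size and checkpoint path are illustrative assumptions, not taken from this package:

import torch
import torch.nn as nn

from diffsynth_engine.models.utils import no_init_weights

# Build a module without paying for random weight initialization;
# the patched init functions are restored when the block exits.
with no_init_weights():
    layer = nn.Linear(4096, 4096)  # illustrative size

# The weights are uninitialized memory at this point, so load real ones before use.
layer.load_state_dict(torch.load("linear_weights.pt"))  # hypothetical checkpoint path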
File without changes
@@ -0,0 +1,497 @@
+ import math
+ import json
+ import torch
+ import torch.nn as nn
+ import torch.distributed as dist
+ from typing import Tuple, Optional
+ from einops import rearrange
+
+ from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+ from diffsynth_engine.models.basic.attention import attention, long_context_attention
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.utils.constants import (
+     WAN_DIT_1_3B_T2V_CONFIG_FILE,
+     WAN_DIT_14B_I2V_CONFIG_FILE,
+     WAN_DIT_14B_T2V_CONFIG_FILE,
+ )
+ from diffsynth_engine.utils.gguf import gguf_inference
+ from diffsynth_engine.utils.parallel import (
+     get_sp_group,
+     get_sp_world_size,
+     get_sp_rank,
+ )
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
+     return x * (1 + scale) + shift
+
+
+ def sinusoidal_embedding_1d(dim, position):
+     sinusoid = torch.outer(
+         position.type(torch.float64),
+         torch.pow(10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2)),
+     )
+     x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+     return x.to(position.dtype)
+
+
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
+     # 3d rope precompute
+     f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
+     h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+     w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+     return f_freqs_cis, h_freqs_cis, w_freqs_cis
+
+
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
+     # 1d rope precompute
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
+     freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def rope_apply(x, freqs):
+     x_out = torch.view_as_complex(x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
+     x_out = torch.view_as_real(x_out * freqs)
+     return x_out.to(x.dtype).flatten(3)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(
+         self,
+         dim,
+         eps=1e-5,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.eps = eps
+         self.dim = dim
+         self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+
+     def norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         return self.norm(x.float()).to(x.dtype) * self.weight
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         eps: float = 1e-6,
+         attn_impl: Optional[str] = None,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = dim // num_heads
+         self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.attn_impl = attn_impl
+
+     def forward(self, x, freqs):
+         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+         num_heads = q.shape[2] // self.head_dim
+         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+         if getattr(self, "use_usp", False):
+             x = long_context_attention(
+                 q=rope_apply(q, freqs),
+                 k=rope_apply(k, freqs),
+                 v=v,
+                 attn_impl=self.attn_impl,
+             )
+         else:
+             x = attention(
+                 q=rope_apply(q, freqs),
+                 k=rope_apply(k, freqs),
+                 v=v,
+                 attn_impl=self.attn_impl,
+             )
+         x = x.flatten(2)
+         return self.o(x)
+
+
+ class CrossAttention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         eps: float = 1e-6,
+         has_image_input: bool = False,
+         attn_impl: Optional[str] = None,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = dim // num_heads
+
+         self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
+         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.has_image_input = has_image_input
+         if has_image_input:
+             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
+             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
+             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.attn_impl = attn_impl
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         if self.has_image_input:
+             img = y[:, :257]
+             ctx = y[:, 257:]
+         else:
+             ctx = y
+         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(ctx)), self.v(ctx)
+         num_heads = q.shape[2] // self.head_dim
+         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+
+         x = attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
+         if self.has_image_input:
+             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
+             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
+             y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
+             x = x + y
+         return self.o(x)
+
+
+ class DiTBlock(nn.Module):
+     def __init__(
+         self,
+         has_image_input: bool,
+         dim: int,
+         num_heads: int,
+         ffn_dim: int,
+         eps: float = 1e-6,
+         attn_impl: Optional[str] = None,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.ffn_dim = ffn_dim
+         self.self_attn = SelfAttention(dim, num_heads, eps, attn_impl=attn_impl, device=device, dtype=dtype)
+         self.cross_attn = CrossAttention(
+             dim, num_heads, eps, has_image_input=has_image_input, attn_impl=attn_impl, device=device, dtype=dtype
+         )
+         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.norm3 = nn.LayerNorm(dim, eps=eps, device=device, dtype=dtype)
+         self.ffn = nn.Sequential(
+             nn.Linear(dim, ffn_dim, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(ffn_dim, dim, device=device, dtype=dtype),
+         )
+         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
+
+     def forward(self, x, context, t_mod, freqs):
+         # msa: multi-head self-attention  mlp: multi-layer perceptron
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + t_mod).chunk(6, dim=1)
+         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+         x = x + gate_msa * self.self_attn(input_x, freqs)
+         x = x + self.cross_attn(self.norm3(x), context)
+         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+         x = x + gate_mlp * self.ffn(input_x)
+         return x
+
+
+ class MLP(torch.nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.proj = torch.nn.Sequential(
+             nn.LayerNorm(in_dim, device=device, dtype=dtype),
+             nn.Linear(in_dim, in_dim, device=device, dtype=dtype),
+             nn.GELU(),
+             nn.Linear(in_dim, out_dim, device=device, dtype=dtype),
+             nn.LayerNorm(out_dim, device=device, dtype=dtype),
+         )
+
+     def forward(self, x):
+         return self.proj(x)
+
+
+ class Head(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         out_dim: int,
+         patch_size: Tuple[int, int, int],
+         eps: float,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.patch_size = patch_size
+         self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
+         self.head = nn.Linear(dim, out_dim * math.prod(patch_size), device=device, dtype=dtype)
+         self.modulation = nn.Parameter(torch.randn(1, 2, dim, device=device, dtype=dtype) / dim**0.5)
+
+     def forward(self, x, t_mod):
+         shift, scale = (self.modulation + t_mod).chunk(2, dim=1)
+         x = self.head(self.norm(x) * (1 + scale) + shift)
+         return x
+
+
+ class WanDiTStateDictConverter(StateDictConverter):
+     def convert(self, state_dict):
+         return state_dict
+
+
+ class WanDiT(PreTrainedModel):
+     converter = WanDiTStateDictConverter()
+
+     def __init__(
+         self,
+         dim: int,
+         in_dim: int,
+         ffn_dim: int,
+         out_dim: int,
+         text_dim: int,
+         freq_dim: int,
+         eps: float,
+         patch_size: Tuple[int, int, int],
+         num_heads: int,
+         num_layers: int,
+         has_image_input: bool,
+         attn_impl: Optional[str] = None,
+         use_usp: bool = False,
+         device: str = "cpu",
+         dtype: torch.dtype = torch.bfloat16,
+     ):
+         super().__init__()
+
+         self.dim = dim
+         self.freq_dim = freq_dim
+         self.has_image_input = has_image_input
+         self.patch_size = patch_size
+
+         self.patch_embedding = nn.Conv3d(
+             in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
+         )
+         self.text_embedding = nn.Sequential(
+             nn.Linear(text_dim, dim, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(dim, dim, device=device, dtype=dtype),
+         )
+         self.time_embedding = nn.Sequential(
+             nn.Linear(freq_dim, dim, device=device, dtype=dtype),
+             nn.SiLU(),
+             nn.Linear(dim, dim, device=device, dtype=dtype),
+         )
+
+         self.time_projection = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(dim, dim * 6, device=device, dtype=dtype),
+         )
+         self.blocks = nn.ModuleList(
+             [
+                 DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_impl, device=device, dtype=dtype)
+                 for _ in range(num_layers)
+             ]
+         )
+         self.head = Head(dim, out_dim, patch_size, eps, device=device, dtype=dtype)
+
+         head_dim = dim // num_heads
+         self.freqs = precompute_freqs_cis_3d(head_dim)
+
+         if has_image_input:
+             self.img_emb = MLP(1280, dim, device=device, dtype=dtype)  # clip_feature_dim = 1280
+
+         if use_usp:
+             setattr(self, "use_usp", True)
+             for block in self.blocks:
+                 setattr(block.self_attn, "use_usp", True)
+
+     def patchify(self, x: torch.Tensor):
+         x = self.patch_embedding(x)  # b c f h w -> b 4c f h/2 w/2
+         grid_size = x.shape[2:]
+         x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
+         return x, grid_size  # x, grid_size: (f, h, w)
+
+     def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
+         return rearrange(
+             x,
+             "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
+             f=grid_size[0],
+             h=grid_size[1],
+             w=grid_size[2],
+             x=self.patch_size[0],
+             y=self.patch_size[1],
+             z=self.patch_size[2],
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor,
+         timestep: torch.Tensor,
+         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
+         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+     ):
+         with gguf_inference():
+             t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
+             t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
+             context = self.text_embedding(context)
+             if self.has_image_input:
+                 x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+                 clip_embdding = self.img_emb(clip_feature)
+                 context = torch.cat([clip_embdding, context], dim=1)  # (b, s1 + s2, d)
+             x, (f, h, w) = self.patchify(x)
+             freqs = (
+                 torch.cat(
+                     [
+                         self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                         self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                         self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+                     ],
+                     dim=-1,
+                 )
+                 .reshape(f * h * w, 1, -1)
+                 .to(x.device)
+             )
+             if getattr(self, "use_usp", False):
+                 s, p = x.size(1), get_sp_world_size()  # (sequence_length, parallelism)
+                 split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
+                 x = torch.split(x, split_size, dim=1)[get_sp_rank()]
+                 freqs = torch.split(freqs, split_size, dim=0)[get_sp_rank()]
+
+             for block in self.blocks:
+                 x = block(x, context, t_mod, freqs)
+             x = self.head(x, t)
+
+             if getattr(self, "use_usp", False):
+                 b, d = x.size(0), x.size(2)  # (batch_size, out_dim)
+                 xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
+                 dist.all_gather(xs, x, group=get_sp_group())
+                 x = torch.concat(xs, dim=1)
+             x = self.unpatchify(x, (f, h, w))
+             return x
+
+     @classmethod
+     def from_state_dict(
+         cls,
+         state_dict,
+         device,
+         dtype,
+         model_type="1.3b-t2v",
+         attn_impl: Optional[str] = None,
+         use_usp=False,
+         assign=True,
+     ):
+         if model_type == "1.3b-t2v":
+             config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
+         elif model_type == "14b-t2v":
+             config = json.load(open(WAN_DIT_14B_T2V_CONFIG_FILE, "r"))
+         elif model_type == "14b-i2v":
+             config = json.load(open(WAN_DIT_14B_I2V_CONFIG_FILE, "r"))
+         else:
+             raise ValueError(f"Unsupported model type: {model_type}")
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(
+                 cls, **config, device=device, dtype=dtype, attn_impl=attn_impl, use_usp=use_usp
+             )
+             model = model.requires_grad_(False)
+             model.load_state_dict(state_dict, assign=assign)
+             model.to(device=device, dtype=dtype)
+         return model
+
+     def get_tp_plan(self):
+         from torch.distributed.tensor.parallel import (
+             ColwiseParallel,
+             RowwiseParallel,
+             PrepareModuleInput,
+             PrepareModuleOutput,
+         )
+         from torch.distributed.tensor import Replicate, Shard
+
+         tp_plan = {
+             "text_embedding.0": ColwiseParallel(),
+             "text_embedding.2": RowwiseParallel(),
+             "time_embedding.0": ColwiseParallel(),
+             "time_embedding.2": RowwiseParallel(),
+             "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
+             "blocks.0": PrepareModuleInput(
+                 input_layouts=(Replicate(), None, None, None),
+                 desired_input_layouts=(Shard(1), None, None, None),  # sequence parallel
+                 use_local_output=True,
+             ),
+             "head": PrepareModuleOutput(
+                 output_layouts=Shard(1),
+                 desired_output_layouts=Replicate(),
+                 use_local_output=True,
+             ),
+         }
+         for idx in range(len(self.blocks)):
+             tp_plan.update(
+                 {
+                     f"blocks.{idx}.self_attn": PrepareModuleInput(
+                         input_layouts=(Shard(1), None),
+                         desired_input_layouts=(Replicate(), None),
+                     ),
+                     f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.self_attn.v": ColwiseParallel(),
+                     f"blocks.{idx}.self_attn.o": RowwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.self_attn.norm_q": PrepareModuleOutput(
+                         output_layouts=Shard(1),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.self_attn.norm_k": PrepareModuleOutput(
+                         output_layouts=Shard(1),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn": PrepareModuleInput(
+                         input_layouts=(Shard(1), None),
+                         desired_input_layouts=(Replicate(), None),
+                     ),
+                     f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.cross_attn.v": ColwiseParallel(),
+                     f"blocks.{idx}.cross_attn.o": RowwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.cross_attn.norm_q": PrepareModuleOutput(
+                         output_layouts=Shard(1),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn.norm_k": PrepareModuleOutput(
+                         output_layouts=Shard(1),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Shard(1)),
+                     f"blocks.{idx}.cross_attn.v_img": ColwiseParallel(),
+                     f"blocks.{idx}.cross_attn.norm_k_img": PrepareModuleOutput(
+                         output_layouts=Shard(1),
+                         desired_output_layouts=Shard(-1),
+                     ),
+                     f"blocks.{idx}.ffn": PrepareModuleInput(
+                         input_layouts=(Shard(1),),
+                         desired_input_layouts=(Replicate(),),
+                     ),
+                     f"blocks.{idx}.ffn.0": ColwiseParallel(),
+                     f"blocks.{idx}.ffn.2": RowwiseParallel(output_layouts=Shard(1)),
+                 }
+             )
+         return tp_plan
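The 497-line hunk above is diffsynth_engine/models/wan/wan_dit.py. WanDiT.from_state_dict selects one of the bundled JSON configs by model_type, constructs the model under no_init_weights with torch.nn.utils.skip_init, and loads weights with assign=True. A minimal loading sketch follows; the safetensors format and local file name are assumptions, not something this diff prescribes:

import torch
from safetensors.torch import load_file  # assumed checkpoint format

from diffsynth_engine.models.wan.wan_dit import WanDiT

# Hypothetical local checkpoint; keys must already match the module names above,
# since WanDiTStateDictConverter.convert() returns the state dict unchanged.
state_dict = load_file("wan2.1-t2v-1.3b.safetensors")

dit = WanDiT.from_state_dict(
    state_dict,
    device="cuda:0",
    dtype=torch.bfloat16,
    model_type="1.3b-t2v",  # also accepts "14b-t2v" and "14b-i2v"
)
# forward() then takes the latent video x (b, c, f, h, w), text context, and timestep;
# clip_feature and y are only required when has_image_input is true (the i2v config).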