diffsynth_engine-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/models/wan/wan_image_encoder.py
@@ -0,0 +1,494 @@
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+
+ from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.models.basic.attention import attention
+
+
+ def pos_interpolate(pos, seq_len):
+     if pos.size(1) == seq_len:
+         return pos
+     else:
+         src_grid = int(math.sqrt(pos.size(1)))
+         tar_grid = int(math.sqrt(seq_len))
+         n = pos.size(1) - src_grid * src_grid
+         return torch.cat(
+             [
+                 pos[:, :n],
+                 F.interpolate(
+                     pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
+                     size=(tar_grid, tar_grid),
+                     mode="bicubic",
+                     align_corners=False,
+                 )
+                 .flatten(2)
+                 .transpose(1, 2),
+             ],
+             dim=1,
+         )
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
+         assert dim % num_heads == 0
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.causal = causal
+         self.attn_dropout = attn_dropout
+         self.proj_dropout = proj_dropout
+
+         # layers
+         self.to_qkv = nn.Linear(dim, dim * 3)
+         self.proj = nn.Linear(dim, dim)
+
+     def forward(self, x):
+         """
+         x: [B, L, C].
+         """
+         b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+         # compute query, key, value
+         q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+         # compute attention
+         x = attention(q, k, v)
+         x = x.reshape(b, s, c)
+
+         # output
+         x = self.proj(x)
+         x = F.dropout(x, self.proj_dropout, self.training)
+         return x
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, dim, mid_dim):
+         super().__init__()
+         self.dim = dim
+         self.mid_dim = mid_dim
+
+         # layers
+         self.fc1 = nn.Linear(dim, mid_dim)
+         self.fc2 = nn.Linear(dim, mid_dim)
+         self.fc3 = nn.Linear(mid_dim, dim)
+
+     def forward(self, x):
+         x = F.silu(self.fc1(x)) * self.fc2(x)
+         x = self.fc3(x)
+         return x
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         mlp_ratio,
+         num_heads,
+         post_norm=False,
+         causal=False,
+         activation="quick_gelu",
+         attn_dropout=0.0,
+         proj_dropout=0.0,
+         norm_eps=1e-5,
+     ):
+         assert activation in ["quick_gelu", "gelu", "swi_glu"]
+         super().__init__()
+         self.dim = dim
+         self.mlp_ratio = mlp_ratio
+         self.num_heads = num_heads
+         self.post_norm = post_norm
+         self.causal = causal
+         self.norm_eps = norm_eps
+
+         # layers
+         self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
+         self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
+         self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
+         if activation == "swi_glu":
+             self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+         else:
+             self.mlp = nn.Sequential(
+                 nn.Linear(dim, int(dim * mlp_ratio)),
+                 QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+                 nn.Linear(int(dim * mlp_ratio), dim),
+                 nn.Dropout(proj_dropout),
+             )
+
+     def forward(self, x):
+         if self.post_norm:
+             x = x + self.norm1(self.attn(x))
+             x = x + self.norm2(self.mlp(x))
+         else:
+             x = x + self.attn(self.norm1(x))
+             x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class AttentionPool(nn.Module):
+     def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
+         assert dim % num_heads == 0
+         super().__init__()
+         self.dim = dim
+         self.mlp_ratio = mlp_ratio
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.proj_dropout = proj_dropout
+         self.norm_eps = norm_eps
+
+         # layers
+         gain = 1.0 / math.sqrt(dim)
+         self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+         self.to_q = nn.Linear(dim, dim)
+         self.to_kv = nn.Linear(dim, dim * 2)
+         self.proj = nn.Linear(dim, dim)
+         self.norm = nn.LayerNorm(dim, eps=norm_eps)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, int(dim * mlp_ratio)),
+             QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+             nn.Linear(int(dim * mlp_ratio), dim),
+             nn.Dropout(proj_dropout),
+         )
+
+     def forward(self, x):
+         """
+         x: [B, L, C].
+         """
+         b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+         # compute query, key, value
+         q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+         k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+         # compute attention
+         x = attention(q, k, v, fa_version=2)
+         x = x.reshape(b, 1, c)
+
+         # output
+         x = self.proj(x)
+         x = F.dropout(x, self.proj_dropout, self.training)
+
+         # mlp
+         x = x + self.mlp(self.norm(x))
+         return x[:, 0]
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(
+         self,
+         image_size=224,
+         patch_size=16,
+         dim=768,
+         mlp_ratio=4,
+         out_dim=512,
+         num_heads=12,
+         num_layers=12,
+         pool_type="token",
+         pre_norm=True,
+         post_norm=False,
+         activation="quick_gelu",
+         attn_dropout=0.0,
+         proj_dropout=0.0,
+         embedding_dropout=0.0,
+         norm_eps=1e-5,
+     ):
+         if image_size % patch_size != 0:
+             print("[WARNING] image_size is not divisible by patch_size", flush=True)
+         assert pool_type in ("token", "token_fc", "attn_pool")
+         out_dim = out_dim or dim
+         super().__init__()
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.num_patches = (image_size // patch_size) ** 2
+         self.dim = dim
+         self.mlp_ratio = mlp_ratio
+         self.out_dim = out_dim
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.pool_type = pool_type
+         self.post_norm = post_norm
+         self.norm_eps = norm_eps
+
+         # embeddings
+         gain = 1.0 / math.sqrt(dim)
+         self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
+         if pool_type in ("token", "token_fc"):
+             self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+         self.pos_embedding = nn.Parameter(
+             gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
+         )
+         self.dropout = nn.Dropout(embedding_dropout)
+
+         # transformer
+         self.pre_norm = nn.LayerNorm(dim, eps=norm_eps) if pre_norm else None
+         self.transformer = nn.Sequential(
+             *[
+                 AttentionBlock(
+                     dim,
+                     mlp_ratio,
+                     num_heads,
+                     post_norm,
+                     False,
+                     activation,
+                     attn_dropout,
+                     proj_dropout,
+                     norm_eps,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.post_norm = nn.LayerNorm(dim, eps=norm_eps)
+
+         # head
+         if pool_type == "token":
+             self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+         elif pool_type == "token_fc":
+             self.head = nn.Linear(dim, out_dim)
+         elif pool_type == "attn_pool":
+             self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
+
+     def forward(self, x, interpolation=False, use_31_block=False):
+         b = x.size(0)
+
+         # embeddings
+         x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+         if self.pool_type in ("token", "token_fc"):
+             x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
+         if interpolation:
+             e = pos_interpolate(self.pos_embedding, x.size(1))
+         else:
+             e = self.pos_embedding
+         e = e.to(dtype=x.dtype, device=x.device)
+         x = self.dropout(x + e)
+         if self.pre_norm is not None:
+             x = self.pre_norm(x)
+
+         # transformer
+         if use_31_block:
+             x = self.transformer[:-1](x)
+         else:
+             x = self.transformer(x)
+         return x
+
+
+ class XLMRobertaCLIP(nn.Module):
+     def __init__(
+         self,
+         embed_dim=1024,
+         image_size=224,
+         patch_size=14,
+         vision_dim=1280,
+         vision_mlp_ratio=4,
+         vision_heads=16,
+         vision_layers=32,
+         vision_pool="token",
+         vision_pre_norm=True,
+         vision_post_norm=False,
+         activation="gelu",
+         vocab_size=250002,
+         max_text_len=514,
+         type_size=1,
+         pad_id=1,
+         text_dim=1024,
+         text_heads=16,
+         text_layers=24,
+         text_post_norm=True,
+         text_dropout=0.0,
+         attn_dropout=0.0,
+         proj_dropout=0.0,
+         embedding_dropout=0.0,
+         norm_eps=1e-5,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.vision_dim = vision_dim
+         self.vision_mlp_ratio = vision_mlp_ratio
+         self.vision_heads = vision_heads
+         self.vision_layers = vision_layers
+         self.vision_pre_norm = vision_pre_norm
+         self.vision_post_norm = vision_post_norm
+         self.activation = activation
+         self.vocab_size = vocab_size
+         self.max_text_len = max_text_len
+         self.type_size = type_size
+         self.pad_id = pad_id
+         self.text_dim = text_dim
+         self.text_heads = text_heads
+         self.text_layers = text_layers
+         self.text_post_norm = text_post_norm
+         self.norm_eps = norm_eps
+
+         # models
+         self.visual = VisionTransformer(
+             image_size=image_size,
+             patch_size=patch_size,
+             dim=vision_dim,
+             mlp_ratio=vision_mlp_ratio,
+             out_dim=embed_dim,
+             num_heads=vision_heads,
+             num_layers=vision_layers,
+             pool_type=vision_pool,
+             pre_norm=vision_pre_norm,
+             post_norm=vision_post_norm,
+             activation=activation,
+             attn_dropout=attn_dropout,
+             proj_dropout=proj_dropout,
+             embedding_dropout=embedding_dropout,
+             norm_eps=norm_eps,
+         )
+         self.textual = None
+         self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+     def forward(self, imgs, txt_ids):
+         """
+         imgs: [B, 3, H, W] of torch.float32.
+         - mean: [0.48145466, 0.4578275, 0.40821073]
+         - std: [0.26862954, 0.26130258, 0.27577711]
+         txt_ids: [B, L] of torch.long.
+         Encoded by data.CLIPTokenizer.
+         """
+         xi = self.visual(imgs)
+         xt = self.textual(txt_ids)
+         return xi, xt
+
+     def param_groups(self):
+         groups = [
+             {
+                 "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
+                 "weight_decay": 0.0,
+             },
+             {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
+         ]
+         return groups
+
+
+ def _clip(
+     pretrained_name=None,
+     model_cls=XLMRobertaCLIP,
+     dtype=torch.float32,
+     device="cpu",
+     **kwargs,
+ ):
+     # init model
+     with torch.device(device):
+         model = model_cls(**kwargs)
+
+     # set device
+     model = model.to(dtype=dtype, device=device)
+     output = (model,)
+
+     # mean and std
+     if "siglip" in pretrained_name.lower():
+         mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+     else:
+         mean = [0.48145466, 0.4578275, 0.40821073]
+         std = [0.26862954, 0.26130258, 0.27577711]
+
+     transforms = T.Compose(
+         [
+             T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(mean=mean, std=std),
+         ]
+     )
+     output += (transforms,)
+
+     return output
+
+
+ def clip_xlm_roberta_vit_h_14(pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
+     cfg = dict(
+         embed_dim=1024,
+         image_size=224,
+         patch_size=14,
+         vision_dim=1280,
+         vision_mlp_ratio=4,
+         vision_heads=16,
+         vision_layers=32,
+         vision_pool="token",
+         activation="gelu",
+         vocab_size=250002,
+         max_text_len=514,
+         type_size=1,
+         pad_id=1,
+         text_dim=1024,
+         text_heads=16,
+         text_layers=24,
+         text_post_norm=True,
+         text_dropout=0.0,
+         attn_dropout=0.0,
+         proj_dropout=0.0,
+         embedding_dropout=0.0,
+     )
+     cfg.update(**kwargs)
+     return _clip(pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+ class WanImageEncoderStateDictConverter(StateDictConverter):
+     def _from_diffusers(self, state_dict):
+         return state_dict
+
+     def _from_civitai(self, state_dict):
+         state_dict_ = {}
+         for name, param in state_dict.items():
+             if name.startswith("textual."):
+                 continue
+             name = "model." + name
+             state_dict_[name] = param
+         return state_dict_
+
+     def convert(self, state_dict):
+         if "visual.transformer.9.norm2.weight" in state_dict:
+             state_dict = self._from_civitai(state_dict)
+         else:
+             state_dict = self._from_diffusers(state_dict)
+         return state_dict
+
+
+ class WanImageEncoder(PreTrainedModel):
+     converter = WanImageEncoderStateDictConverter()
+
+     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+         super().__init__()
+         # init model
+         self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device="cpu")
+
+     def encode_image(self, videos):
+         # preprocess
+         size = (self.model.image_size,) * 2
+         videos = torch.cat(
+             [
+                 F.interpolate(
+                     u,
+                     size=size,
+                     mode="bicubic",
+                     align_corners=False,
+                 )
+                 for u in videos
+             ]
+         )
+         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+         # forward
+         out = self.model.visual(videos, use_31_block=True)
+         return out
+
+     @classmethod
+     def from_state_dict(cls, state_dict, device, dtype):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+             model.load_state_dict(state_dict, assign=True)
+             model.to(device=device, dtype=dtype, non_blocking=True)
+         return model
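
Usage sketch for the WanImageEncoder added in this file (illustrative only, not part of the packaged code): it instantiates the encoder with random weights and exercises encode_image on a dummy clip; real use would first load converted weights through from_state_dict. The input range and output shape follow from the code above (inputs in [-1, 1], ViT-H/14 at 224x224 yields 256 patch tokens plus one class token of width 1280).

    import torch
    from diffsynth_engine.models.wan.wan_image_encoder import WanImageEncoder

    # Random-weight instantiation is enough to exercise preprocessing and the ViT
    # forward pass; production use loads converted weights via from_state_dict.
    encoder = WanImageEncoder().eval()

    # encode_image expects an iterable of [N, 3, H, W] tensors with values in [-1, 1];
    # each entry is resized to image_size, normalized, and run through all but the
    # last transformer block (use_31_block=True).
    frames = torch.rand(1, 3, 480, 832) * 2 - 1
    with torch.no_grad():
        tokens = encoder.encode_image([frames])
    print(tokens.shape)  # expected [1, 257, 1280]: 256 patch tokens + 1 class token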