diffsynth-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,264 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Dict
6
+
7
+ from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
8
+ from diffsynth_engine.models.utils import no_init_weights
9
+
10
+
11
+ def fp16_clamp(x):
12
+ if x.dtype == torch.float16 and torch.isinf(x).any():
13
+ clamp = torch.finfo(x.dtype).max - 1000
14
+ x = torch.clamp(x, min=-clamp, max=clamp)
15
+ return x
16
+
17
+
18
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit activation."""

    def forward(self, x):
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
21
+
22
+
23
class T5LayerNorm(nn.Module):
    """T5-style RMS normalization: no mean centering, no bias, learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Variance is computed in float32 for numerical stability.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Cast back down when the scale parameter lives in reduced precision.
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x
35
+
36
+
37
class T5Attention(nn.Module):
    """Multi-head attention without query scaling, per the T5 convention."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # projection layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None (self-attention when None).
        mask: [B, L2] or [B, L1, L2] or None; zero entries are masked out.
        pos_bias: optional additive relative-position bias.
        """
        if context is None:
            context = x
        batch, heads, hdim = x.size(0), self.num_heads, self.head_dim

        # project and split into heads
        query = self.q(x).view(batch, -1, heads, hdim)
        key = self.k(context).view(batch, -1, heads, hdim)
        value = self.v(context).view(batch, -1, heads, hdim)

        # additive attention bias: relative positions plus masking
        bias = x.new_zeros(batch, heads, query.size(1), key.size(1))
        if pos_bias is not None:
            bias = bias + pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            expanded = mask.view(batch, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            bias = bias.masked_fill(expanded == 0, torch.finfo(x.dtype).min)

        # T5 applies no 1/sqrt(d) scaling; softmax runs in float32
        scores = torch.einsum("binc,bjnc->bnij", query, key) + bias
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        out = torch.einsum("bnij,bjnc->binc", weights, value)

        # merge heads and project back
        out = out.reshape(batch, -1, heads * hdim)
        out = self.o(out)
        return self.dropout(out)
87
+
88
+
89
class T5FeedForward(nn.Module):
    """Gated feed-forward block (GELU-gated, as in T5 v1.1)."""

    def __init__(self, dim, dim_ffn, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # gate branch produces the multiplicative activation
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # up-project, gate, then down-project; dropout after each stage
        hidden = self.dropout(self.fc1(x) * self.gate(x))
        return self.dropout(self.fc2(hidden))
107
+
108
+
109
class T5SelfAttention(nn.Module):
    """Pre-norm transformer encoder block: self-attention then feed-forward.

    Residual sums are passed through ``fp16_clamp`` to keep fp16 activations
    finite.
    """

    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # sub-layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # With shared positions the bias is supplied by the caller; otherwise
        # each block owns its own relative-position table.
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=bias))
        return fp16_clamp(x + self.ffn(self.norm2(x)))
131
+
132
+
133
class T5RelativeEmbedding(nn.Module):
    """Bucketed relative-position bias table in the T5 style.

    ``forward(lq, lk)`` returns an additive attention bias of shape
    [1, num_heads, lq, lk].
    """

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # one learned bias per (bucket, head) pair
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # offset of every key position relative to every query position
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        buckets = self._relative_position_bucket(rel_pos)
        bias = self.embedding(buckets)  # [Lq, Lk, N]
        bias = bias.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return bias.contiguous()

    def _relative_position_bucket(self, rel_pos):
        if self.bidirectional:
            # Half the buckets encode "key after query", half "key before".
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # Exact buckets for small distances; distances in [max_exact, max_dist)
        # map logarithmically onto the remaining buckets.
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (
            torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
        ).long()
        # cap everything beyond max_dist at the last bucket
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
176
+
177
+
178
def init_weights(m):
    """Initialize T5 submodules with the original T5 scheme.

    Intended for use via ``model.apply(init_weights)``; modules of other
    types are left untouched.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        # input projections scale with 1/sqrt(dim), output with 1/sqrt(dim_ffn)
        in_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=in_std)
        nn.init.normal_(m.fc1.weight, std=in_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
192
+
193
+
194
class WanTextEncoderStateDictConverter(StateDictConverter):
    """State-dict converter for :class:`WanTextEncoder`.

    Every conversion path is the identity: checkpoint keys are used as-is,
    with no renaming for either diffusers- or civitai-style checkpoints.
    """

    def convert(self, state_dict):
        return state_dict

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
203
+
204
+
205
class WanTextEncoder(PreTrainedModel):
    """UMT5-style encoder-only text encoder for the Wan video pipeline.

    Pipeline: token embedding -> dropout -> ``num_layers`` pre-norm
    self-attention blocks -> final RMS norm -> dropout.
    """

    converter = WanTextEncoderStateDictConverter()

    def __init__(
        self,
        vocab=256384,  # vocabulary size, or a pre-built nn.Embedding to reuse
        dim=4096,  # hidden width
        dim_attn=4096,  # total attention width (num_heads * head_dim)
        dim_ffn=10240,  # feed-forward inner width
        num_heads=64,
        num_layers=24,
        num_buckets=32,  # relative-position buckets per bias table
        shared_pos=False,  # True: one shared position table; False: one per block
        dropout=0.0,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        # NOTE(review): `device` and `dtype` are accepted but not used here;
        # placement/casting appears to happen in `from_state_dict` (and
        # `torch.nn.utils.skip_init` consumes the `device` kwarg) — confirm
        # that direct construction on the requested device is not expected.
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        # shared_pos=True: a single bias table owned here and passed to every
        # block; otherwise each T5SelfAttention builds its own.
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

    def forward(self, ids, mask=None):
        """Encode token ids [B, L] into hidden states [B, L, dim].

        mask: optional attention mask forwarded to every block; zero entries
        are masked out of attention.
        """
        x = self.token_embedding(ids)
        x = self.dropout(x)
        # Compute the shared relative-position bias once for all blocks.
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
    ):
        """Build a model directly from checkpoint weights, skipping random init.

        `skip_init` + `no_init_weights` avoid allocating/initializing weights
        that `load_state_dict(assign=True)` immediately replaces with the
        checkpoint tensors; the model is then moved/cast to the target
        device/dtype.
        """
        with no_init_weights():
            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
            model.load_state_dict(state_dict, assign=True)
            model.to(device=device, dtype=dtype, non_blocking=True)
        return model