diffsynth_engine-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
diffsynth_engine/models/sd3/sd3_dit.py
@@ -0,0 +1,302 @@
+ import json
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ from einops import rearrange
+
+ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+ from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.utils.constants import SD3_DIT_CONFIG_FILE
+ from diffsynth_engine.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ with open(SD3_DIT_CONFIG_FILE, "r") as f:
+     config = json.load(f)
+
+
+ class SD3DiTStateDictConverter(StateDictConverter):
+     def __init__(self):
+         pass
+
+     def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["diffusers"]["rename_dict"]
+         state_dict_ = {}
+         for name, param in state_dict.items():
+             if name in rename_dict:
+                 if name == "pos_embed.pos_embed":
+                     param = param.reshape((1, 192, 192, 1536))
+                 state_dict_[rename_dict[name]] = param
+             elif name.endswith(".weight") or name.endswith(".bias"):
+                 suffix = ".weight" if name.endswith(".weight") else ".bias"
+                 prefix = name[: -len(suffix)]
+                 if prefix in rename_dict:
+                     state_dict_[rename_dict[prefix] + suffix] = param
+                 elif prefix.startswith("transformer_blocks."):
+                     names = prefix.split(".")
+                     names[0] = "blocks"
+                     middle = ".".join(names[2:])
+                     if middle in rename_dict:
+                         name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+                         state_dict_[name_] = param
+         return state_dict_
+
+     def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["civitai"]["rename_dict"]
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
+                     param = torch.concat([param[1536:], param[:1536]], axis=0)
+                 elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
+                     param = torch.concat([param[1536:], param[:1536]], axis=0)
+                 elif name == "model.diffusion_model.pos_embed":
+                     param = param.reshape((1, 192, 192, 1536))
+                 if isinstance(rename_dict[name], str):
+                     state_dict_[rename_dict[name]] = param
+                 else:
+                     name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
+                     state_dict_[name_] = param
+         return state_dict_
+
+     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         if "model.diffusion_model.context_embedder.weight" in state_dict:
+             state_dict = self._from_civitai(state_dict)
+             logger.info("use civitai format state dict")
+         elif "time_text_embed.timestep_embedder.linear_1" in state_dict:
+             state_dict = self._from_diffusers(state_dict)
+             logger.info("use diffusers format state dict")
+         else:
+             logger.info("use diffsynth format state dict")
+         return state_dict
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(
+         self,
+         patch_size=2,
+         in_channels=16,
+         embed_dim=1536,
+         pos_embed_max_size=192,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.float16,
+     ):
+         super().__init__()
+         self.pos_embed_max_size = pos_embed_max_size
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(
+             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, device=device, dtype=dtype
+         )
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536, device=device, dtype=dtype)
+         )
+
+     def cropped_pos_embed(self, height, width):
+         height = height // self.patch_size
+         width = width // self.patch_size
+         top = (self.pos_embed_max_size - height) // 2
+         left = (self.pos_embed_max_size - width) // 2
+         spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
+         return spatial_pos_embed
+
+     def forward(self, latent):
+         height, width = latent.shape[-2:]
+         latent = self.proj(latent)
+         latent = latent.flatten(2).transpose(1, 2)
+         pos_embed = self.cropped_pos_embed(height, width)
+         return latent + pos_embed
+
+
+ class JointAttention(nn.Module):
+     def __init__(
+         self,
+         dim_a,
+         dim_b,
+         num_heads,
+         head_dim,
+         only_out_a=False,
+         device: str = "cuda:0",
+         dtype: torch.dtype = torch.float16,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+         self.only_out_a = only_out_a
+
+         self.a_to_qkv = nn.Linear(dim_a, dim_a * 3, device=device, dtype=dtype)
+         self.b_to_qkv = nn.Linear(dim_b, dim_b * 3, device=device, dtype=dtype)
+
+         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
+         if not only_out_a:
+             self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
+
+     def forward(self, hidden_states_a, hidden_states_b):
+         batch_size = hidden_states_a.shape[0]
+
+         qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
+         qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+         q, k, v = qkv.chunk(3, dim=1)
+
+         hidden_states = nn.functional.scaled_dot_product_attention(q, k, v)
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+         hidden_states = hidden_states.to(q.dtype)
+         hidden_states_a, hidden_states_b = (
+             hidden_states[:, : hidden_states_a.shape[1]],
+             hidden_states[:, hidden_states_a.shape[1] :],
+         )
+         hidden_states_a = self.a_to_out(hidden_states_a)
+         if self.only_out_a:
+             return hidden_states_a
+         else:
+             hidden_states_b = self.b_to_out(hidden_states_b)
+             return hidden_states_a, hidden_states_b
+
+
+ class JointTransformerBlock(nn.Module):
+     def __init__(self, dim, num_attention_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+         super().__init__()
+         self.norm1_a = AdaLayerNorm(dim, device=device, dtype=dtype)
+         self.norm1_b = AdaLayerNorm(dim, device=device, dtype=dtype)
+
+         self.attn = JointAttention(
+             dim, dim, num_attention_heads, dim // num_attention_heads, device=device, dtype=dtype
+         )
+
+         self.norm2_a = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+         self.ff_a = nn.Sequential(
+             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+         )
+
+         self.norm2_b = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+         self.ff_b = nn.Sequential(
+             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+         )
+
+     def forward(self, hidden_states_a, hidden_states_b, temb):
+         norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+         norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+         # Attention
+         attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+         # Part A
+         hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+         norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+         hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+         # Part B
+         hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+         norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+         hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+         return hidden_states_a, hidden_states_b
+
+
+ class JointTransformerFinalBlock(nn.Module):
+     def __init__(self, dim, num_attention_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+         super().__init__()
+         self.norm1_a = AdaLayerNorm(dim, device=device, dtype=dtype)
+         self.norm1_b = AdaLayerNorm(dim, single=True, device=device, dtype=dtype)
+
+         self.attn = JointAttention(
+             dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, device=device, dtype=dtype
+         )
+
+         self.norm2_a = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+         self.ff_a = nn.Sequential(
+             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+         )
+
+     def forward(self, hidden_states_a, hidden_states_b, temb):
+         norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+         norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
+
+         # Attention
+         attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+         # Part A
+         hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+         norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+         hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+         return hidden_states_a, hidden_states_b
+
+
+ class SD3DiT(PreTrainedModel):
+     converter = SD3DiTStateDictConverter()
+
+     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+         super().__init__()
+         self.pos_embedder = PatchEmbed(
+             patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192, device=device, dtype=dtype
+         )
+         self.time_embedder = TimestepEmbeddings(256, 1536, device=device, dtype=dtype)
+         self.pooled_text_embedder = nn.Sequential(
+             nn.Linear(2048, 1536, device=device, dtype=dtype),
+             nn.SiLU(),
+             nn.Linear(1536, 1536, device=device, dtype=dtype),
+         )
+         self.context_embedder = nn.Linear(4096, 1536, device=device, dtype=dtype)
+         self.blocks = nn.ModuleList(
+             [JointTransformerBlock(1536, 24, device=device, dtype=dtype) for _ in range(23)]
+             + [JointTransformerFinalBlock(1536, 24, device=device, dtype=dtype)]
+         )
+         self.norm_out = AdaLayerNorm(1536, single=True, device=device, dtype=dtype)
+         self.proj_out = nn.Linear(1536, 64, device=device, dtype=dtype)
+
+     def forward(
+         self,
+         hidden_states,
+         timestep,
+         prompt_emb,
+         pooled_prompt_emb,
+         use_gradient_checkpointing=False,
+     ):
+         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+         prompt_emb = self.context_embedder(prompt_emb)
+
+         height, width = hidden_states.shape[-2:]
+         hidden_states = self.pos_embedder(hidden_states)
+
+         def create_custom_forward(module):
+             def custom_forward(*inputs):
+                 return module(*inputs)
+
+             return custom_forward
+
+         for block in self.blocks:
+             if self.training and use_gradient_checkpointing:
+                 hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     prompt_emb,
+                     conditioning,
+                     use_reentrant=False,
+                 )
+             else:
+                 hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+
+         hidden_states = self.norm_out(hidden_states, conditioning)
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = rearrange(
+             hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height // 2, W=width // 2
+         )
+         return hidden_states
+
+     @classmethod
+     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+             model.load_state_dict(state_dict, assign=True)
+             model.to(device=device, dtype=dtype, non_blocking=True)
+         return model
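
For orientation, a minimal usage sketch of the SD3DiT above. It is illustrative only: the checkpoint path is a placeholder, and it assumes SD3DiT is re-exported from diffsynth_engine.models.sd3 and that weights arrive in one of the three key layouts probed by convert(). Note also the arithmetic in cropped_pos_embed: a 128x128 latent becomes a 64x64 patch grid, so the learned 192x192 positional grid is center-cropped with top = left = (192 - 64) // 2 = 64.

# Illustrative sketch, not part of the package diff above.
# Assumptions: SD3DiT is importable as below, and "sd3_dit.safetensors"
# is a placeholder checkpoint path.
import torch
from safetensors.torch import load_file
from diffsynth_engine.models.sd3 import SD3DiT

raw = load_file("sd3_dit.safetensors")
state_dict = SD3DiT.converter.convert(raw)  # autodetects civitai/diffusers/diffsynth keys
dit = SD3DiT.from_state_dict(state_dict, device="cuda:0", dtype=torch.float16)

# 16-channel latents; patch_size=2 turns a 128x128 latent into 4096 tokens,
# and the final rearrange maps the 64-dim output (2*2*16) back to (1, 16, 128, 128).
latents = torch.randn(1, 16, 128, 128, device="cuda:0", dtype=torch.float16)
timestep = torch.tensor([500.0], device="cuda:0")
prompt_emb = torch.randn(1, 77, 4096, device="cuda:0", dtype=torch.float16)  # joint text context
pooled = torch.randn(1, 2048, device="cuda:0", dtype=torch.float16)  # pooled CLIP embeddings
with torch.no_grad():
    out = dit(latents, timestep, prompt_emb, pooled)  # same shape as the input latents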
diffsynth_engine/models/sd3/sd3_text_encoder.py
@@ -0,0 +1,163 @@
+ import json
+ import torch
+ from typing import Dict
+
+ from diffsynth_engine.models.components.t5 import T5EncoderModel
+ from diffsynth_engine.models.base import StateDictConverter
+ from diffsynth_engine.models.sd import SDTextEncoder
+ from diffsynth_engine.models.sdxl import SDXLTextEncoder2
+ from diffsynth_engine.models.utils import no_init_weights
+ from diffsynth_engine.utils.constants import SD3_TEXT_ENCODER_CONFIG_FILE
+ from diffsynth_engine.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ with open(SD3_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+     config = json.load(f)
+
+
+ class SD3TextEncoder1StateDictConverter(StateDictConverter):
+     def __init__(self):
+         pass
+
+     def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["diffusers"]["te1_rename_dict"]
+         attn_rename_dict = config["diffusers"]["te1_attn_rename_dict"]
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name == "text_model.embeddings.position_embedding.weight":
+                     param = param.reshape((1, param.shape[0], param.shape[1]))
+                 state_dict_[rename_dict[name]] = param
+             elif name.startswith("text_model.encoder.layers."):
+                 param = state_dict[name]
+                 names = name.split(".")
+                 layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                 name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                 state_dict_[name_] = param
+         return state_dict_
+
+     def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["civitai"]["te1_rename_dict"]
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
+                     param = param.reshape((1, param.shape[0], param.shape[1]))
+                 state_dict_[rename_dict[name]] = param
+         return state_dict_
+
+     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+             state_dict = self._from_civitai(state_dict)
+             logger.info("use civitai format state dict")
+         elif "text_model.embeddings.token_embedding.weight" in state_dict:
+             state_dict = self._from_diffusers(state_dict)
+             logger.info("use diffusers format state dict")
+         else:
+             logger.info("use diffsynth format state dict")
+         return state_dict
+
+
+ class SD3TextEncoder1(SDTextEncoder):
+     converter = SD3TextEncoder1StateDictConverter()
+
+     def __init__(self, vocab_size=49408, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+         super().__init__(vocab_size=vocab_size, device=device, dtype=dtype)
+
+     def forward(self, input_ids, clip_skip=2):
+         embeds = self.token_embedding(input_ids) + self.position_embeds
+         attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+         for encoder_id, encoder in enumerate(self.encoders):
+             embeds = encoder(embeds, attn_mask=attn_mask)
+             if encoder_id + clip_skip == len(self.encoders):
+                 hidden_states = embeds
+         embeds = self.final_layer_norm(embeds)
+         pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+         return hidden_states, pooled_embeds
+
+     @classmethod
+     def from_state_dict(
+         cls,
+         state_dict: Dict[str, torch.Tensor],
+         device: str,
+         dtype: torch.dtype,
+         vocab_size: int = 49408,
+     ):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, vocab_size=vocab_size)
+             model.load_state_dict(state_dict)
+         return model
+
+
+ class SD3TextEncoder2StateDictConverter(StateDictConverter):
+     def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["diffusers"]["te2_rename_dict"]
+         attn_rename_dict = config["diffusers"]["te2_attn_rename_dict"]
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name == "text_model.embeddings.position_embedding.weight":
+                     param = param.reshape((1, param.shape[0], param.shape[1]))
+                 state_dict_[rename_dict[name]] = param
+             elif name.startswith("text_model.encoder.layers."):
+                 param = state_dict[name]
+                 names = name.split(".")
+                 layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                 name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                 state_dict_[name_] = param
+         return state_dict_
+
+     def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         rename_dict = config["civitai"]["te2_rename_dict"]
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
+                     param = param.reshape((1, param.shape[0], param.shape[1]))
+                 state_dict_[rename_dict[name]] = param
+         return state_dict_
+
+     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+             state_dict = self._from_civitai(state_dict)
+             logger.info("use civitai format state dict")
+         elif "text_model.final_layer_norm.weight" in state_dict:
+             state_dict = self._from_diffusers(state_dict)
+             logger.info("use diffusers format state dict")
+         else:
+             logger.info("use diffsynth format state dict")
+         return state_dict
+
+
+ class SD3TextEncoder2(SDXLTextEncoder2):
+     converter = SD3TextEncoder2StateDictConverter()
+
+     def __init__(self):
+         super().__init__()
+
+
+ class SD3TextEncoder3(T5EncoderModel):
+     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+         super().__init__(
+             embed_dim=4096,
+             vocab_size=32128,
+             num_encoder_layers=24,
+             d_ff=10240,
+             num_heads=64,
+             relative_attention_num_buckets=32,
+             relative_attention_max_distance=128,
+             dropout_rate=0.0,
+             eps=1e-6,
+             device=device,
+             dtype=dtype,
+         )
+
+     def forward(self, input_ids):
+         outputs = super().forward(input_ids=input_ids)
+         prompt_emb = outputs.last_hidden_state
+         return prompt_emb
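
One detail of SD3TextEncoder1.forward worth spelling out: with the default clip_skip=2, the returned hidden_states are the output of the penultimate encoder layer, while pooled_embeds are gathered from the final, layer-normed sequence at each sample's highest token id, which in the CLIP vocabulary is the end-of-text token. A small sketch under stated assumptions (12 encoder layers as in CLIP-L; BOS 49406, EOS 49407; the other ids below are made up):

# Illustrative sketch of the indexing above; not part of the package diff.
import torch

num_layers, clip_skip = 12, 2
for encoder_id in range(num_layers):
    # embeds = encoder(embeds, attn_mask=attn_mask) runs first, so the
    # snapshot below captures the output of layer index 10, i.e. the
    # penultimate layer, exactly when encoder_id + clip_skip == num_layers.
    if encoder_id + clip_skip == num_layers:
        pass  # hidden_states = embeds

# The pooled embedding indexes the EOS position: argmax over token ids
# works because 49407 is the largest id in the CLIP vocabulary.
input_ids = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])
eos_position = input_ids.to(torch.int).argmax(dim=-1)  # tensor([3])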
diffsynth_engine/models/sd3/sd3_vae.py
@@ -0,0 +1,43 @@
+ import torch
+ from typing import Dict
+
+ from diffsynth_engine.models.components.vae import VAEDecoder, VAEEncoder
+ from diffsynth_engine.models.utils import no_init_weights
+
+
+ class SD3VAEEncoder(VAEEncoder):
+     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+         super().__init__(
+             latent_channels=16,
+             scaling_factor=1.5305,
+             shift_factor=0.0609,
+             use_quant_conv=False,
+             device=device,
+             dtype=dtype,
+         )
+
+     @classmethod
+     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+             model.load_state_dict(state_dict)
+         return model
+
+
+ class SD3VAEDecoder(VAEDecoder):
+     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+         super().__init__(
+             latent_channels=16,
+             scaling_factor=1.5305,
+             shift_factor=0.0609,
+             use_post_quant_conv=False,
+             device=device,
+             dtype=dtype,
+         )
+
+     @classmethod
+     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+         with no_init_weights():
+             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+             model.load_state_dict(state_dict)
+         return model
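
The SD3-specific content here is pure configuration: 16 latent channels and the scaling_factor/shift_factor pair. Assuming the base VAEEncoder/VAEDecoder apply these constants in the usual diffusers convention for SD3-family VAEs (an assumption, since the base classes are defined elsewhere in this package), the implied latent normalization is:

# Hedged sketch of the affine latent map implied by the constants above.
scaling_factor, shift_factor = 1.5305, 0.0609

def normalize_latent(z_raw):
    # encoder side: shift, then scale, before latents enter the DiT
    return (z_raw - shift_factor) * scaling_factor

def denormalize_latent(z):
    # decoder side: invert the map before decoding back to pixels
    return z / scaling_factor + shift_factor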
diffsynth_engine/models/sdxl/__init__.py
@@ -0,0 +1,13 @@
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2, config as sdxl_text_encoder_config
+ from .sdxl_unet import SDXLUNet, config as sdxl_unet_config
+ from .sdxl_vae import SDXLVAEDecoder, SDXLVAEEncoder
+
+ __all__ = [
+     "SDXLTextEncoder",
+     "SDXLTextEncoder2",
+     "SDXLUNet",
+     "SDXLVAEDecoder",
+     "SDXLVAEEncoder",
+     "sdxl_text_encoder_config",
+     "sdxl_unet_config",
+ ]