diffsynth_engine-0.1.0-py3-none-any.whl
This diff represents the content of a publicly available package version as released to one of the supported registries; it is provided for informational purposes only.
- diffsynth_engine/__init__.py +25 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/models/__init__.py +0 -0
- diffsynth_engine/models/base.py +55 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +137 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +393 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +504 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/attention.py +200 -0
- diffsynth_engine/models/wan/wan_dit.py +431 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +17 -0
- diffsynth_engine/pipelines/base.py +216 -0
- diffsynth_engine/pipelines/flux_image.py +548 -0
- diffsynth_engine/pipelines/sd_image.py +386 -0
- diffsynth_engine/pipelines/sdxl_image.py +430 -0
- diffsynth_engine/pipelines/wan_video.py +481 -0
- diffsynth_engine/tokenizers/__init__.py +4 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +79 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +139 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +14 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +191 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
- diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
- diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
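The wheel ships its tokenizer assets inside the package itself (the conf/ tree above), and the modules in the hunks below resolve in-package files through constants such as SD3_DIT_CONFIG_FILE and SD3_TEXT_ENCODER_CONFIG_FILE from diffsynth_engine/utils/constants.py, which this section does not show. A hedged sketch of how in-wheel asset paths like these are typically resolved; the variable names below are illustrative, not the actual contents of constants.py:

    import diffsynth_engine
    from pathlib import Path

    # Resolve assets relative to the installed package root; this is the common
    # pattern for data files shipped inside a wheel (the exact mechanism used by
    # constants.py is an assumption, since that file is not part of this diff).
    PACKAGE_ROOT = Path(diffsynth_engine.__file__).parent
    FLUX_TOKENIZER_1_DIR = PACKAGE_ROOT / "conf" / "tokenizers" / "flux" / "tokenizer_1"
    UMT5_XXL_TOKENIZER_DIR = PACKAGE_ROOT / "conf" / "tokenizers" / "wan" / "umt5-xxl"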
diffsynth_engine/models/sd3/sd3_dit.py
@@ -0,0 +1,302 @@
+import json
+import torch
+import torch.nn as nn
+from typing import Dict
+from einops import rearrange
+
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.utils.constants import SD3_DIT_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SD3_DIT_CONFIG_FILE, "r") as f:
+    config = json.load(f)
+
+
+class SD3DiTStateDictConverter(StateDictConverter):
+    def __init__(self):
+        pass
+
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if name in rename_dict:
+                if name == "pos_embed.pos_embed":
+                    param = param.reshape((1, 192, 192, 1536))
+                state_dict_[rename_dict[name]] = param
+            elif name.endswith(".weight") or name.endswith(".bias"):
+                suffix = ".weight" if name.endswith(".weight") else ".bias"
+                prefix = name[: -len(suffix)]
+                if prefix in rename_dict:
+                    state_dict_[rename_dict[prefix] + suffix] = param
+                elif prefix.startswith("transformer_blocks."):
+                    names = prefix.split(".")
+                    names[0] = "blocks"
+                    middle = ".".join(names[2:])
+                    if middle in rename_dict:
+                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+                        state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
+                    param = torch.concat([param[1536:], param[:1536]], axis=0)
+                elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
+                    param = torch.concat([param[1536:], param[:1536]], axis=0)
+                elif name == "model.diffusion_model.pos_embed":
+                    param = param.reshape((1, 192, 192, 1536))
+                if isinstance(rename_dict[name], str):
+                    state_dict_[rename_dict[name]] = param
+                else:
+                    name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
+                    state_dict_[name_] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "model.diffusion_model.context_embedder.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "time_text_embed.timestep_embedder.linear_1" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size=2,
+        in_channels=16,
+        embed_dim=1536,
+        pos_embed_max_size=192,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.pos_embed_max_size = pos_embed_max_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, device=device, dtype=dtype
+        )
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536, device=device, dtype=dtype)
+        )
+
+    def cropped_pos_embed(self, height, width):
+        height = height // self.patch_size
+        width = width // self.patch_size
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
+        return spatial_pos_embed
+
+    def forward(self, latent):
+        height, width = latent.shape[-2:]
+        latent = self.proj(latent)
+        latent = latent.flatten(2).transpose(1, 2)
+        pos_embed = self.cropped_pos_embed(height, width)
+        return latent + pos_embed
+
+
+class JointAttention(nn.Module):
+    def __init__(
+        self,
+        dim_a,
+        dim_b,
+        num_heads,
+        head_dim,
+        only_out_a=False,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.only_out_a = only_out_a
+
+        self.a_to_qkv = nn.Linear(dim_a, dim_a * 3, device=device, dtype=dtype)
+        self.b_to_qkv = nn.Linear(dim_b, dim_b * 3, device=device, dtype=dtype)
+
+        self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
+        if not only_out_a:
+            self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
+
+    def forward(self, hidden_states_a, hidden_states_b):
+        batch_size = hidden_states_a.shape[0]
+
+        qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
+        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q, k, v = qkv.chunk(3, dim=1)
+
+        hidden_states = nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+        hidden_states_a, hidden_states_b = (
+            hidden_states[:, : hidden_states_a.shape[1]],
+            hidden_states[:, hidden_states_a.shape[1] :],
+        )
+        hidden_states_a = self.a_to_out(hidden_states_a)
+        if self.only_out_a:
+            return hidden_states_a
+        else:
+            hidden_states_b = self.b_to_out(hidden_states_b)
+            return hidden_states_a, hidden_states_b
+
+
+class JointTransformerBlock(nn.Module):
+    def __init__(self, dim, num_attention_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.norm1_a = AdaLayerNorm(dim, device=device, dtype=dtype)
+        self.norm1_b = AdaLayerNorm(dim, device=device, dtype=dtype)
+
+        self.attn = JointAttention(
+            dim, dim, num_attention_heads, dim // num_attention_heads, device=device, dtype=dtype
+        )
+
+        self.norm2_a = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+        self.ff_a = nn.Sequential(
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+        )
+
+        self.norm2_b = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+        self.ff_b = nn.Sequential(
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+        )
+
+    def forward(self, hidden_states_a, hidden_states_b, temb):
+        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+        # Attention
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+        # Part A
+        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+        # Part B
+        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+        return hidden_states_a, hidden_states_b
+
+
+class JointTransformerFinalBlock(nn.Module):
+    def __init__(self, dim, num_attention_heads, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.norm1_a = AdaLayerNorm(dim, device=device, dtype=dtype)
+        self.norm1_b = AdaLayerNorm(dim, single=True, device=device, dtype=dtype)
+
+        self.attn = JointAttention(
+            dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, device=device, dtype=dtype
+        )
+
+        self.norm2_a = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+        self.ff_a = nn.Sequential(
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype),
+        )
+
+    def forward(self, hidden_states_a, hidden_states_b, temb):
+        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+        norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
+
+        # Attention
+        attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+        # Part A
+        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+        return hidden_states_a, hidden_states_b
+
+
+class SD3DiT(PreTrainedModel):
+    converter = SD3DiTStateDictConverter()
+
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.pos_embedder = PatchEmbed(
+            patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192, device=device, dtype=dtype
+        )
+        self.time_embedder = TimestepEmbeddings(256, 1536, device=device, dtype=dtype)
+        self.pooled_text_embedder = nn.Sequential(
+            nn.Linear(2048, 1536, device=device, dtype=dtype),
+            nn.SiLU(),
+            nn.Linear(1536, 1536, device=device, dtype=dtype),
+        )
+        self.context_embedder = nn.Linear(4096, 1536, device=device, dtype=dtype)
+        self.blocks = nn.ModuleList(
+            [JointTransformerBlock(1536, 24, device=device, dtype=dtype) for _ in range(23)]
+            + [JointTransformerFinalBlock(1536, 24, device=device, dtype=dtype)]
+        )
+        self.norm_out = AdaLayerNorm(1536, single=True, device=device, dtype=dtype)
+        self.proj_out = nn.Linear(1536, 64, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        prompt_emb,
+        pooled_prompt_emb,
+        use_gradient_checkpointing=False,
+    ):
+        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+        prompt_emb = self.context_embedder(prompt_emb)
+
+        height, width = hidden_states.shape[-2:]
+        hidden_states = self.pos_embedder(hidden_states)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+
+            return custom_forward
+
+        for block in self.blocks:
+            if self.training and use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+
+        hidden_states = self.norm_out(hidden_states, conditioning)
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = rearrange(
+            hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height // 2, W=width // 2
+        )
+        return hidden_states
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
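sd3_dit.py pins the SD3-medium geometry in code rather than config: 2x2 patches over a 16-channel latent, 1536-dim hidden states, 24 attention heads, 23 JointTransformerBlocks plus one JointTransformerFinalBlock, and a final rearrange that unpatchifies proj_out's 2*2*16 = 64 values per token back to the input latent shape. A minimal shape-check sketch against only the signatures above (randomly initialized weights, a CUDA device assumed, and the 77-token context length is an illustrative choice):

    import torch
    from diffsynth_engine.models.sd3.sd3_dit import SD3DiT

    dit = SD3DiT(device="cuda:0", dtype=torch.float16)  # random weights; shapes only

    latents = torch.randn(1, 16, 128, 128, device="cuda:0", dtype=torch.float16)   # a 1024x1024 image at 1/8 scale
    timestep = torch.tensor([500.0], device="cuda:0")
    prompt_emb = torch.randn(1, 77, 4096, device="cuda:0", dtype=torch.float16)    # per-token context, 4096 -> 1536
    pooled_prompt_emb = torch.randn(1, 2048, device="cuda:0", dtype=torch.float16)  # pooled text embedding, 2048 -> 1536

    out = dit(latents, timestep, prompt_emb, pooled_prompt_emb)
    assert out.shape == latents.shape  # unpatchify restores (B, 16, H, W)

With real weights, SD3DiT.converter.convert(...) first maps civitai- or diffusers-style key names onto this module tree, and SD3DiT.from_state_dict(...) loads them without a throwaway random initialization, which is what pairing no_init_weights() with torch.nn.utils.skip_init buys.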
diffsynth_engine/models/sd3/sd3_text_encoder.py
@@ -0,0 +1,163 @@
+import json
+import torch
+from typing import Dict
+
+from diffsynth_engine.models.components.t5 import T5EncoderModel
+from diffsynth_engine.models.base import StateDictConverter
+from diffsynth_engine.models.sd import SDTextEncoder
+from diffsynth_engine.models.sdxl import SDXLTextEncoder2
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.utils.constants import SD3_TEXT_ENCODER_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SD3_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+    config = json.load(f)
+
+
+class SD3TextEncoder1StateDictConverter(StateDictConverter):
+    def __init__(self):
+        pass
+
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["te1_rename_dict"]
+        attn_rename_dict = config["diffusers"]["te1_attn_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                param = state_dict[name]
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["te1_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "text_model.embeddings.token_embedding.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SD3TextEncoder1(SDTextEncoder):
+    converter = SD3TextEncoder1StateDictConverter()
+
+    def __init__(self, vocab_size=49408, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__(vocab_size=vocab_size, device=device, dtype=dtype)
+
+    def forward(self, input_ids, clip_skip=2):
+        embeds = self.token_embedding(input_ids) + self.position_embeds
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                hidden_states = embeds
+        embeds = self.final_layer_norm(embeds)
+        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+        return hidden_states, pooled_embeds
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        vocab_size: int = 49408,
+    ):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, vocab_size=vocab_size)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class SD3TextEncoder2StateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["te2_rename_dict"]
+        attn_rename_dict = config["diffusers"]["te2_attn_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                param = state_dict[name]
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["te2_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "text_model.final_layer_norm.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SD3TextEncoder2(SDXLTextEncoder2):
+    converter = SD3TextEncoder2StateDictConverter()
+
+    def __init__(self):
+        super().__init__()
+
+
+class SD3TextEncoder3(T5EncoderModel):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__(
+            embed_dim=4096,
+            vocab_size=32128,
+            num_encoder_layers=24,
+            d_ff=10240,
+            num_heads=64,
+            relative_attention_num_buckets=32,
+            relative_attention_max_distance=128,
+            dropout_rate=0.0,
+            eps=1e-6,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input_ids):
+        outputs = super().forward(input_ids=input_ids)
+        prompt_emb = outputs.last_hidden_state
+        return prompt_emb
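Two details of SD3TextEncoder1.forward above are easy to miss: clip_skip=2 keeps the hidden states produced by the second-to-last CLIP layer rather than the final one, and the pooled embedding is gathered at the end-of-text token, which is located by argmax over the token ids because CLIP's EOT token (id 49407) is the largest id in the 49408-token vocabulary. A small sketch of just that indexing, with an illustrative token sequence:

    import torch

    # BOS (49406), two prompt tokens, then EOT (49407); CLIP-style padding
    # repeats EOT, and argmax returns the first occurrence of the maximum.
    input_ids = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407]])
    eot_position = input_ids.to(dtype=torch.int).argmax(dim=-1)
    print(eot_position)  # tensor([3])

    # The encoder then pools the normalized hidden state at that position
    # (768 is CLIP-L's hidden size; the tensor here is random, for shape only):
    embeds = torch.randn(1, 6, 768)
    pooled = embeds[torch.arange(embeds.shape[0]), eot_position]
    print(pooled.shape)  # torch.Size([1, 768])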
diffsynth_engine/models/sd3/sd3_vae.py
@@ -0,0 +1,43 @@
+import torch
+from typing import Dict
+
+from diffsynth_engine.models.components.vae import VAEDecoder, VAEEncoder
+from diffsynth_engine.models.utils import no_init_weights
+
+
+class SD3VAEEncoder(VAEEncoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=16,
+            scaling_factor=1.5305,
+            shift_factor=0.0609,
+            use_quant_conv=False,
+            device=device,
+            dtype=dtype,
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class SD3VAEDecoder(VAEDecoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=16,
+            scaling_factor=1.5305,
+            shift_factor=0.0609,
+            use_post_quant_conv=False,
+            device=device,
+            dtype=dtype,
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
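sd3_vae.py only pins SD3's VAE geometry (16 latent channels, no quant/post-quant convs) and its normalization constants; the encode/decode logic lives in diffsynth_engine/models/components/vae.py, outside this section. A sketch of how scaling_factor=1.5305 and shift_factor=0.0609 are conventionally applied, assuming the same convention diffusers uses for SD3:

    scaling_factor, shift_factor = 1.5305, 0.0609

    def normalize_latents(z):
        # After encoding: recenter raw VAE latents and rescale them to roughly
        # unit variance before handing them to the diffusion model.
        return (z - shift_factor) * scaling_factor

    def denormalize_latents(latents):
        # Before decoding: invert the normalization.
        return latents / scaling_factor + shift_factor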
diffsynth_engine/models/sdxl/__init__.py
@@ -0,0 +1,13 @@
+from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2, config as sdxl_text_encoder_config
+from .sdxl_unet import SDXLUNet, config as sdxl_unet_config
+from .sdxl_vae import SDXLVAEDecoder, SDXLVAEEncoder
+
+__all__ = [
+    "SDXLTextEncoder",
+    "SDXLTextEncoder2",
+    "SDXLUNet",
+    "SDXLVAEDecoder",
+    "SDXLVAEEncoder",
+    "sdxl_text_encoder_config",
+    "sdxl_unet_config",
+]
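The re-export module matters beyond convenience: sd3_text_encoder.py in this same wheel imports through the subpackage root (from diffsynth_engine.models.sdxl import SDXLTextEncoder2), so the __all__ surface above is part of the package's internal API. Consumers do the same:

    from diffsynth_engine.models.sdxl import SDXLTextEncoder2, sdxl_unet_config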