diffsynth-engine 0.1.0__py3-none-any.whl
This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- diffsynth_engine/__init__.py +25 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/models/__init__.py +0 -0
- diffsynth_engine/models/base.py +55 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +137 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +393 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +504 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/attention.py +200 -0
- diffsynth_engine/models/wan/wan_dit.py +431 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +17 -0
- diffsynth_engine/pipelines/base.py +216 -0
- diffsynth_engine/pipelines/flux_image.py +548 -0
- diffsynth_engine/pipelines/sd_image.py +386 -0
- diffsynth_engine/pipelines/sdxl_image.py +430 -0
- diffsynth_engine/pipelines/wan_video.py +481 -0
- diffsynth_engine/tokenizers/__init__.py +4 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +79 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +139 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +14 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +191 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
- diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
- diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/diffsynth_engine/models/sdxl/sdxl_text_encoder.py
@@ -0,0 +1,307 @@
+import json
+import torch
+import torch.nn as nn
+from typing import Dict
+
+from diffsynth_engine.models.components.clip import CLIPEncoderLayer
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.utils.constants import SDXL_TEXT_ENCODER_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SDXL_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+    config = json.load(f)
+
+
+class SDXLTextEncoderStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["te1_rename_dict"]
+        attn_rename_dict = config["diffusers"]["te1_attn_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                param = state_dict[name]
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                if layer_id == "11":
+                    # we don't need the last layer
+                    continue
+                state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["te1_rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if not name.startswith("conditioner.embedders.0"):
+                continue
+            if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
+                param = param.reshape((1, param.shape[0], param.shape[1]))
+                name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding"
+                suffix = ""
+            else:
+                name, suffix = split_suffix(name)
+            if name in rename_dict:
+                name_ = rename_dict[name] + suffix
+                state_dict_[name_] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "text_model.final_layer_norm.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SDXLTextEncoder2StateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["te2_rename_dict"]
+        attn_rename_dict = config["diffusers"]["te2_attn_rename_dict"]
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                param = state_dict[name]
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["te2_rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if not name.startswith("conditioner.embedders.1"):
+                continue
+            if name.endswith(".in_proj_weight"):
+                name = name.replace(".in_proj_weight", ".in_proj")
+                suffix = ".weight"
+            elif name.endswith(".in_proj_bias"):
+                name = name.replace(".in_proj_bias", ".in_proj")
+                suffix = ".bias"
+            elif name == "conditioner.embedders.1.model.text_projection":
+                name = "conditioner.embedders.1.model.text_projection"
+                suffix = ".weight"
+                param = param.T
+            elif name == "conditioner.embedders.1.model.positional_embedding":
+                name = "conditioner.embedders.1.model.positional_embedding"
+                suffix = ""
+                param = param.reshape((1, param.shape[0], param.shape[1]))
+            else:
+                name, suffix = split_suffix(name)
+
+            if name in rename_dict:
+                if isinstance(rename_dict[name], str):
+                    name_ = rename_dict[name] + suffix
+                    state_dict_[name_] = param
+                else:
+                    length = param.shape[0] // 3
+                    for i, rename in enumerate(rename_dict[name]):
+                        name_ = rename + suffix
+                        state_dict_[name_] = param[i * length : i * length + length]
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "text_model.final_layer_norm.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SDXLTextEncoder(PreTrainedModel):
+    converter = SDXLTextEncoderStateDictConverter()
+
+    def __init__(
+        self,
+        embed_dim=768,
+        vocab_size=49408,
+        max_position_embeddings=77,
+        num_encoder_layers=11,
+        encoder_intermediate_size=3072,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+
+        # token_embedding
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device, dtype=dtype)
+
+        # position_embeds (This is a fixed tensor)
+        self.position_embeds = nn.Parameter(
+            torch.zeros(1, max_position_embeddings, embed_dim, device=device, dtype=dtype)
+        )
+
+        # encoders
+        self.encoders = nn.ModuleList(
+            [
+                CLIPEncoderLayer(embed_dim, encoder_intermediate_size, device=device, dtype=dtype)
+                for _ in range(num_encoder_layers)
+            ]
+        )
+
+        # attn_mask
+        self.attn_mask = self.attention_mask(max_position_embeddings)
+
+        # The text encoder is different to that in Stable Diffusion 1.x.
+        # It does not include final_layer_norm.
+
+    def attention_mask(self, length):
+        mask = torch.empty(length, length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def forward(self, input_ids, clip_skip=2):
+        clip_skip = max(
+            clip_skip - 1, 1
+        )  # Because we did not load the last layer of the encoder, the clip_skip needs to be decreased by 1.
+        embeds = self.token_embedding(input_ids) + self.position_embeds
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                break
+        return embeds
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        embed_dim: int = 768,
+        vocab_size: int = 49408,
+        max_position_embeddings: int = 77,
+        num_encoder_layers: int = 11,
+        encoder_intermediate_size: int = 3072,
+    ):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(
+                cls,
+                device=device,
+                dtype=dtype,
+                embed_dim=embed_dim,
+                vocab_size=vocab_size,
+                max_position_embeddings=max_position_embeddings,
+                num_encoder_layers=num_encoder_layers,
+                encoder_intermediate_size=encoder_intermediate_size,
+            )
+        model.load_state_dict(state_dict)
+        return model
+
+
+class SDXLTextEncoder2(PreTrainedModel):
+    converter = SDXLTextEncoder2StateDictConverter()
+
+    def __init__(
+        self,
+        embed_dim=1280,
+        vocab_size=49408,
+        max_position_embeddings=77,
+        num_encoder_layers=32,
+        encoder_intermediate_size=5120,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+
+        # token_embedding
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device, dtype=dtype)
+
+        # position_embeds (This is a fixed tensor)
+        self.position_embeds = nn.Parameter(
+            torch.zeros(1, max_position_embeddings, embed_dim, device=device, dtype=dtype)
+        )
+
+        # encoders
+        self.encoders = nn.ModuleList(
+            [
+                CLIPEncoderLayer(
+                    embed_dim,
+                    encoder_intermediate_size,
+                    num_heads=20,
+                    head_dim=64,
+                    use_quick_gelu=False,
+                    device=device,
+                    dtype=dtype,
+                )
+                for _ in range(num_encoder_layers)
+            ]
+        )
+
+        # attn_mask
+        self.attn_mask = self.attention_mask(max_position_embeddings)
+
+        # final_layer_norm
+        self.final_layer_norm = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
+
+        # text_projection
+        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, device=device, dtype=dtype)
+
+    def attention_mask(self, length):
+        mask = torch.empty(length, length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def forward(self, input_ids, clip_skip=2):
+        clip_skip = max(clip_skip, 1)
+        embeds = self.token_embedding(input_ids) + self.position_embeds
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                hidden_states = embeds
+        embeds = self.final_layer_norm(embeds)
+        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+        pooled_embeds = self.text_projection(pooled_embeds)
+        return hidden_states, pooled_embeds
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        embed_dim: int = 1280,
+        vocab_size: int = 49408,
+        max_position_embeddings: int = 77,
+        num_encoder_layers: int = 32,
+        encoder_intermediate_size: int = 5120,
+    ):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(
+                cls,
+                device=device,
+                dtype=dtype,
+                embed_dim=embed_dim,
+                vocab_size=vocab_size,
+                max_position_embeddings=max_position_embeddings,
+                num_encoder_layers=num_encoder_layers,
+                encoder_intermediate_size=encoder_intermediate_size,
+            )
+        model.load_state_dict(state_dict)
+        return model
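For reference, the converters above are exposed through each model's `converter` class attribute, and `from_state_dict` consumes the converted weights. A minimal loading sketch, assuming these classes are re-exported from `diffsynth_engine.models.sdxl` (per the package's `__init__.py` in the manifest) and using a hypothetical checkpoint path with `safetensors`:

    import torch
    from safetensors.torch import load_file

    from diffsynth_engine.models.sdxl import SDXLTextEncoder, SDXLTextEncoder2

    # hypothetical single-file SDXL checkpoint (civitai-style key layout)
    state_dict = load_file("sdxl_checkpoint.safetensors")

    # each converter detects the source format (civitai / diffusers / diffsynth)
    # by probing for marker keys, then renames keys into this package's layout
    te1_state_dict = SDXLTextEncoder.converter.convert(state_dict)
    te2_state_dict = SDXLTextEncoder2.converter.convert(state_dict)

    text_encoder = SDXLTextEncoder.from_state_dict(te1_state_dict, device="cuda:0", dtype=torch.float16)
    text_encoder_2 = SDXLTextEncoder2.from_state_dict(te2_state_dict, device="cuda:0", dtype=torch.float16)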
--- /dev/null
+++ b/diffsynth_engine/models/sdxl/sdxl_unet.py
@@ -0,0 +1,306 @@
+import json
+import torch
+import torch.nn as nn
+from typing import Dict
+
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.unet_helper import (
+    ResnetBlock,
+    AttentionBlock,
+    PushBlock,
+    DownSampler,
+    PopBlock,
+    UpSampler,
+)
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.utils.constants import SDXL_UNET_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SDXL_UNET_CONFIG_FILE, "r") as f:
+    config = json.load(f)
+
+
+class SDXLUNetStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        # architecture
+        block_types = [
+            "ResnetBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "PushBlock",
+            "DownSampler",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "DownSampler",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "UpSampler",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "UpSampler",
+            "PopBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+        ]
+
+        # rename each parameter
+        name_list = sorted([name for name in state_dict])
+        rename_dict = {}
+        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+        for name in name_list:
+            names = name.split(".")
+            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
+                pass
+            elif names[0] in ["encoder_hid_proj"]:
+                names[0] = "text_intermediate_proj"
+            elif names[0] in ["time_embedding", "add_embedding"]:
+                if names[0] == "add_embedding":
+                    names[0] = "add_time_embedding"
+                else:
+                    names[0] = "time_embedding.timestep_embedder"
+                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+                if names[0] == "mid_block":
+                    names.insert(1, "0")
+                block_type = {
+                    "resnets": "ResnetBlock",
+                    "attentions": "AttentionBlock",
+                    "downsamplers": "DownSampler",
+                    "upsamplers": "UpSampler",
+                }[names[2]]
+                block_type_with_id = ".".join(names[:4])
+                if block_type_with_id != last_block_type_with_id[block_type]:
+                    block_id[block_type] += 1
+                    last_block_type_with_id[block_type] = block_type_with_id
+                    while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+                        block_id[block_type] += 1
+                block_type_with_id = ".".join(names[:4])
+                names = ["blocks", str(block_id[block_type])] + names[4:]
+                if "ff" in names:
+                    ff_index = names.index("ff")
+                    component = ".".join(names[ff_index : ff_index + 3])
+                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+                    names = names[:ff_index] + [component] + names[ff_index + 3 :]
+                if "to_out" in names:
+                    names.pop(names.index("to_out") + 1)
+            else:
+                raise ValueError(f"Unknown parameters: {name}")
+            rename_dict[name] = ".".join(names)
+
+        # convert state_dict
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if ".proj_in." in name or ".proj_out." in name:
+                param = param.squeeze()
+            state_dict_[rename_dict[name]] = param
+        if "text_intermediate_proj.weight" in state_dict_:
+            return state_dict_, {"is_kolors": True}
+        else:
+            return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            name, suffix = split_suffix(name)
+            if name in rename_dict:
+                if ".proj_in." in name or ".proj_out." in name:
+                    param = param.squeeze()
+                name_ = rename_dict[name] + suffix
+                state_dict_[name_] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "model.diffusion_model.input_blocks.0.0.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "down_blocks.0.resnets.0.conv1.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SDXLUNet(PreTrainedModel):
+    converter = SDXLUNetStateDictConverter()
+
+    def __init__(
+        self,
+        is_kolors: bool = False,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+        use_gradient_checkpointing: bool = False,
+    ):
+        super().__init__()
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
+        self.add_time_embedding = nn.Sequential(
+            nn.Linear(5632 if is_kolors else 2816, 1280, device=device, dtype=dtype),
+            nn.SiLU(),
+            nn.Linear(1280, 1280, device=device, dtype=dtype),
+        )
+        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
+        self.text_intermediate_proj = nn.Linear(4096, 2048, device=device, dtype=dtype) if is_kolors else None
+
+        self.blocks = nn.ModuleList(
+            [
+                # DownBlock2D
+                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                PushBlock(),
+                DownSampler(320, device=device, dtype=dtype),
+                PushBlock(),
+                # CrossAttnDownBlock2D
+                ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                PushBlock(),
+                DownSampler(640, device=device, dtype=dtype),
+                PushBlock(),
+                # CrossAttnDownBlock2D
+                ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                PushBlock(),
+                # UNetMidBlock2DCrossAttn
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                # CrossAttnUpBlock2D
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(1920, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+                UpSampler(1280, device=device, dtype=dtype),
+                # CrossAttnUpBlock2D
+                PopBlock(),
+                ResnetBlock(1920, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(1280, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(960, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
+                UpSampler(640, device=device, dtype=dtype),
+                # UpBlock2D
+                PopBlock(),
+                ResnetBlock(960, 320, 1280, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
+            ]
+        )
+
+        self.conv_norm_out = nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5, device=device, dtype=dtype)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1, device=device, dtype=dtype)
+
+        self.is_kolors = is_kolors
+
+    def forward(self, x, timestep, context, y, **kwargs):
+        # 1. time embedding
+        t_emb = self.time_embedding(timestep, dtype=x.dtype)
+        ## add embedding
+        add_embeds = self.add_time_embedding(y)
+
+        time_emb = t_emb + add_embeds
+
+        # 2. pre-process
+        hidden_states = self.conv_in(x)
+        text_emb = context if self.text_intermediate_proj is None else self.text_intermediate_proj(context)
+        res_stack = [hidden_states]
+
+        # 3. blocks
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+
+            return custom_forward
+
+        for i, block in enumerate(self.blocks):
+            if (
+                self.training
+                and self.use_gradient_checkpointing
+                and not (isinstance(block, PushBlock) or isinstance(block, PopBlock))
+            ):
+                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    time_emb,
+                    text_emb,
+                    res_stack,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, time_emb, text_emb, res_stack = block(
+                    hidden_states,
+                    time_emb,
+                    text_emb,
+                    res_stack,
+                )
+
+        # 4. output
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+    @classmethod
+    def from_state_dict(
+        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, is_kolors: bool = False
+    ):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, is_kolors=is_kolors)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
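The flattened `blocks` list above linearizes diffusers' nested down/mid/up blocks into one sequence threaded through an explicit residual stack: `PushBlock` saves the current hidden states for a later skip connection, and `PopBlock` retrieves the most recently saved states and concatenates them along the channel dimension, which is why the up-path `ResnetBlock`s take widened inputs such as 2560 (1280 + 1280). A toy sketch of that mechanism, using stand-in classes rather than the actual `unet_helper` implementations:

    import torch

    class PushBlock:
        # save hidden states on the residual stack for a later skip connection
        def __call__(self, hidden_states, time_emb, text_emb, res_stack):
            res_stack.append(hidden_states)
            return hidden_states, time_emb, text_emb, res_stack

    class PopBlock:
        # pop the matching skip connection and concatenate along channels
        def __call__(self, hidden_states, time_emb, text_emb, res_stack):
            res = res_stack.pop()
            hidden_states = torch.cat([hidden_states, res], dim=1)
            return hidden_states, time_emb, text_emb, res_stack

    x = torch.randn(1, 1280, 16, 16)
    res_stack = [x]
    x, _, _, res_stack = PushBlock()(x, None, None, res_stack)
    x, _, _, res_stack = PopBlock()(x, None, None, res_stack)
    print(x.shape)  # torch.Size([1, 2560, 16, 16]) -- a widened up-path input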
--- /dev/null
+++ b/diffsynth_engine/models/sdxl/sdxl_vae.py
@@ -0,0 +1,38 @@
+import torch
+from typing import Dict
+
+from diffsynth_engine.models.components.vae import VAEDecoder, VAEEncoder
+from diffsynth_engine.models.utils import no_init_weights
+
+
+class SDXLVAEEncoder(VAEEncoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=4, scaling_factor=0.13025, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class SDXLVAEDecoder(VAEDecoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=4,
+            scaling_factor=0.13025,
+            shift_factor=0,
+            use_post_quant_conv=True,
+            device=device,
+            dtype=dtype,
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
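Both classes only pin the SDXL-specific hyperparameters (4 latent channels, scaling factor 0.13025, no shift) on top of the shared `VAEEncoder`/`VAEDecoder`, and they load through the same skip-init pattern as the other models. A minimal sketch, assuming weights already converted to this package's key layout, stored in hypothetical safetensors files, and re-exported class names from `diffsynth_engine.models.sdxl`:

    import torch
    from safetensors.torch import load_file

    from diffsynth_engine.models.sdxl import SDXLVAEEncoder, SDXLVAEDecoder

    # hypothetical pre-converted weight files for the two VAE halves
    encoder_state_dict = load_file("sdxl_vae_encoder.safetensors")
    decoder_state_dict = load_file("sdxl_vae_decoder.safetensors")

    # fp32 is the default dtype here; the original SDXL VAE is known to
    # produce NaNs in fp16, so keeping it in fp32 is the safe choice
    vae_encoder = SDXLVAEEncoder.from_state_dict(encoder_state_dict, device="cuda:0", dtype=torch.float32)
    vae_decoder = SDXLVAEDecoder.from_state_dict(decoder_state_dict, device="cuda:0", dtype=torch.float32)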