diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
from .attention import Attention
|
|
2
|
+
from einops import repeat, rearrange
|
|
3
|
+
import math
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
    """Normalize q/k, apply rotary position embedding, and manage a k/v cache.

    The module is meant to be used as a ``qkv_preprocessor``: it receives the
    projected query/key/value tensors and returns the processed triple.
    """

    def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
        """Create the q/k LayerNorms and an empty key/value cache.

        Args:
            q_norm_shape: size of the last dimension normalized for queries.
            k_norm_shape: size of the last dimension normalized for keys.
            rotary_emb_on_k: when False only the queries are rotated.
        """
        super().__init__()
        self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.rotary_emb_on_k = rotary_emb_on_k
        self.k_cache, self.v_cache = [], []

    def reshape_for_broadcast(self, freqs_cis, x):
        """View the (cos, sin) pair so it broadcasts over all but the last two dims of ``x``."""
        # Keep the trailing two dimensions of x; every leading dimension becomes 1.
        leading_ones = [1] * (x.ndim - 2)
        target_shape = leading_ones + list(x.shape[-2:])
        cos_table, sin_table = freqs_cis[0], freqs_cis[1]
        return cos_table.view(*target_shape), sin_table.view(*target_shape)

    def rotate_half(self, x):
        """Map each consecutive pair (a, b) of the last dim to (-b, a), computed in float32."""
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack([-odd, even], dim=-1)
        # flatten(3) merges the pair axis back; this fixes the input rank at 4.
        return rotated.flatten(3)

    def apply_rotary_emb(self, xq, xk, freqs_cis):
        """Rotate ``xq`` (and ``xk`` unless it is None) with the given (cos, sin) tables.

        Returns:
            ``(xq_out, xk_out)``; ``xk_out`` is None when ``xk`` is None.
        """
        cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
        cos = cos.to(xq.device)
        sin = sin.to(xq.device)

        def rotate(tensor):
            # Compute in float32, then cast back to the caller's dtype.
            as_float = tensor.float()
            return (as_float * cos + self.rotate_half(as_float) * sin).type_as(tensor)

        xq_out = rotate(xq)
        xk_out = rotate(xk) if xk is not None else None
        return xq_out, xk_out

    def forward(self, q, k, v, freqs_cis_img, to_cache=False):
        """Normalize, rotate, and optionally cache or extend the keys/values.

        With ``to_cache=True`` the processed k and the raw v are appended to the
        cache. Otherwise, if both caches are non-empty, their entries are
        concatenated after the current k/v along dim 2 and the caches are reset.
        """
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE: keys are rotated only when configured to do so.
        keys_to_rotate = k if self.rotary_emb_on_k else None
        q, rotated_k = self.apply_rotary_emb(q, keys_to_rotate, freqs_cis_img)
        if self.rotary_emb_on_k:
            k = rotated_k

        if to_cache:
            self.k_cache.append(k)
            self.v_cache.append(v)
        elif self.k_cache and self.v_cache:
            k = torch.concat([k, *self.k_cache], dim=2)
            v = torch.concat([v, *self.v_cache], dim=2)
            self.k_cache, self.v_cache = [], []
        return q, k, v
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class FP32_Layernorm(torch.nn.LayerNorm):
    """LayerNorm that always computes in float32 and returns the input's dtype.

    The affine parameters are upcast for the computation as well. They may be
    ``None`` (``elementwise_affine=False`` or ``bias=False``), in which case
    they are passed through unchanged instead of raising on ``None.float()``.
    """

    def forward(self, inputs):
        """Normalize ``inputs`` over ``self.normalized_shape`` in float32.

        Args:
            inputs: tensor of any floating dtype.

        Returns:
            The normalized tensor cast back to ``inputs.dtype``.
        """
        origin_dtype = inputs.dtype
        # weight/bias are None when the layer was built without affine parameters.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        normalized = torch.nn.functional.layer_norm(
            inputs.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normalized.to(origin_dtype)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class FP32_SiLU(torch.nn.SiLU):
    """SiLU activation evaluated in float32, returned in the input's dtype."""

    def forward(self, inputs):
        """Apply SiLU to an upcast copy of ``inputs`` and cast the result back."""
        upcast = inputs.float()
        # Always out-of-place, regardless of how the module was configured.
        activated = torch.nn.functional.silu(upcast, inplace=False)
        return activated.to(inputs.dtype)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class HunyuanDiTFinalLayer(torch.nn.Module):
    """Output head: adaptive-LayerNorm modulation followed by a linear projection."""

    def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
        """Build the non-affine norm, the patch projection and the modulation MLP.

        Args:
            final_hidden_size: width of the incoming token features.
            condition_dim: width of the conditioning vector.
            patch_size: side length of a patch; the projection emits
                ``patch_size * patch_size * out_channels`` values per token.
            out_channels: number of channels produced per patch pixel.
        """
        super().__init__()
        values_per_token = patch_size * patch_size * out_channels
        self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(final_hidden_size, values_per_token, bias=True)
        # One linear layer yields both shift and scale, hence the doubled width.
        modulation_proj = torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
        self.adaLN_modulation = torch.nn.Sequential(FP32_SiLU(), modulation_proj)

    def modulate(self, x, shift, scale):
        """Return ``x * (1 + scale) + shift`` with shift/scale broadcast over dim 1."""
        scaled = x * (1 + scale[:, None])
        return scaled + shift[:, None]

    def forward(self, hidden_states, condition_emb):
        """Modulate the normalized tokens with the condition and project them."""
        modulation = self.adaLN_modulation(condition_emb)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(hidden_states)
        return self.linear(self.modulate(normalized, shift, scale))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class HunyuanDiTBlock(torch.nn.Module):
    """One transformer block: optional long skip, self-attention, cross-attention, MLP."""

    def __init__(
        self,
        hidden_dim=1408,
        condition_dim=1408,
        num_heads=16,
        mlp_ratio=4.3637,
        text_dim=1024,
        skip_connection=False
    ):
        """Create the sub-layers.

        Args:
            hidden_dim: token feature width.
            condition_dim: width of the conditioning vector fed to ``modulation``.
            num_heads: attention head count; head width is ``hidden_dim // num_heads``.
            mlp_ratio: MLP hidden width is ``int(hidden_dim * mlp_ratio)``.
            text_dim: feature width of the text tokens used as cross-attention k/v.
            skip_connection: when True the block fuses a residual from an earlier
                block through ``skip_norm`` / ``skip_linear``.
        """
        super().__init__()
        head_dim = hidden_dim // num_heads
        mlp_hidden_dim = int(hidden_dim * mlp_ratio)

        self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota1 = HunyuanDiTRotaryEmbedding(head_dim, head_dim)
        self.attn1 = Attention(hidden_dim, num_heads, head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota2 = HunyuanDiTRotaryEmbedding(head_dim, head_dim, rotary_emb_on_k=False)
        self.attn2 = Attention(hidden_dim, num_heads, head_dim, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, mlp_hidden_dim, bias=True),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(mlp_hidden_dim, hidden_dim, bias=True)
        )
        # The skip layers exist only on blocks that consume a residual.
        self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True) if skip_connection else None
        self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True) if skip_connection else None

    def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
        """Run the block.

        Args:
            hidden_states: token features.
            condition_emb: conditioning vector; it is projected and added (as a
                shift) to the normalized self-attention input.
            text_emb: text tokens used as keys/values in cross-attention.
            freq_cis_img: rotary tables forwarded to the rotary modules.
            residual: features concatenated on the last dim when the block has
                skip layers; required in that case.
            to_cache: forwarded to the self-attention rotary module only.

        Returns:
            The updated token features.
        """
        def self_attention_preprocessor(q, k, v):
            return self.rota1(q, k, v, freq_cis_img, to_cache=to_cache)

        def cross_attention_preprocessor(q, k, v):
            return self.rota2(q, k, v, freq_cis_img)

        # Long skip connection
        has_skip_layers = self.skip_norm is not None and self.skip_linear is not None
        if has_skip_layers:
            fused = torch.cat([hidden_states, residual], dim=-1)
            hidden_states = self.skip_linear(self.skip_norm(fused))

        # Self-attention, shifted by the projected condition
        shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
        self_attn_out = self.attn1(
            self.norm1(hidden_states) + shift_msa,
            qkv_preprocessor=self_attention_preprocessor,
        )
        hidden_states = hidden_states + self_attn_out

        # Cross-attention (note: uses norm3; norm2 belongs to the MLP below)
        cross_attn_out = self.attn2(
            self.norm3(hidden_states),
            text_emb,
            qkv_preprocessor=cross_attention_preprocessor,
        )
        hidden_states = hidden_states + cross_attn_out

        # Feed-forward
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class AttentionPool(torch.nn.Module):
    """Pool a token sequence into one vector by attending from its mean token."""

    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
        """Create the positional table and the q/k/v/output projections.

        Args:
            spacial_dim: sequence length; the table holds ``spacial_dim + 1`` rows
                because a mean token is prepended in ``forward``.
            embed_dim: token feature width.
            num_heads: number of attention heads.
            output_dim: width of the pooled vector; falls back to ``embed_dim``.
        """
        super().__init__()
        initial_table = torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5
        self.positional_embedding = torch.nn.Parameter(initial_table)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return one pooled vector per batch element.

        Args:
            x: tensor laid out as (batch, length, channels), per the NLC -> LNC
                permutation below.
        """
        tokens = x.permute(1, 0, 2)  # NLC -> LNC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (L+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (L+1)NC

        # The functional API expects the three input biases packed as q, k, v.
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = torch.nn.functional.multi_head_attention_forward(
            query=tokens[:1],  # only the mean token queries the sequence
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return pooled.squeeze(0)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class PatchEmbed(torch.nn.Module):
    """Turn an image-like tensor into a token sequence with a strided convolution."""

    def __init__(
        self,
        patch_size=(2, 2),
        in_chans=4,
        embed_dim=1408,
        bias=True,
    ):
        """Create the projection; kernel size and stride both equal ``patch_size``."""
        super().__init__()
        self.proj = torch.nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward(self, x):
        """Project ``x`` and flatten the spatial grid into the token axis (BCHW -> BNC)."""
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    """Embed a 1-D tensor of timesteps into ``dim`` features.

    Adapted from
    https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

    Args:
        t: 1-D tensor of timesteps (it is indexed as ``t[:, None]``).
        dim: width of the returned embedding.
        max_period: controls the lowest frequency of the sinusoids.
        repeat_only: when True, skip the sinusoids and simply repeat ``t``
            ``dim`` times along a new last axis.

    Returns:
        A ``(len(t), dim)`` tensor laid out as ``[cos | sin]`` and zero-padded by
        one column when ``dim`` is odd, or the repeated timesteps.
    """
    if repeat_only:
        return repeat(t, "b -> b d", d=dim)

    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32)
    # size: [dim/2], an exponentially decaying curve
    freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd width: append a zero column so the result has exactly ``dim`` features.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class TimestepEmbedder(torch.nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, hidden_size=1408, frequency_embedding_size=256):
        """Create the MLP.

        Args:
            hidden_size: width of both MLP layers' outputs.
            frequency_embedding_size: width of the sinusoidal features fed to the MLP.
        """
        super().__init__()
        input_proj = torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        output_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = torch.nn.Sequential(input_proj, torch.nn.SiLU(), output_proj)
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        """Return the MLP output for the sinusoidal embedding of ``t``."""
        # Match the MLP's parameter dtype so half-precision weights accept the input.
        mlp_dtype = self.mlp[0].weight.dtype
        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(mlp_dtype)
        return self.mlp(t_freq)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class HunyuanDiT(torch.nn.Module):
    """Hunyuan-DiT denoising transformer.

    The network patch-embeds the input, runs ``num_layers_down`` plain blocks
    followed by ``num_layers_up`` blocks that consume long skip connections,
    and unpatchifies the output of the final layer. Only the first half of the
    predicted channels is returned.
    """

    def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
        """Build the embedders, the transformer blocks and the output head.

        Args:
            num_layers_down: number of blocks without a skip connection.
            num_layers_up: number of blocks that consume a stored residual.
            in_channels: channels of the input fed to the patch embedder.
            out_channels: channels predicted per patch pixel by the final layer.
            hidden_dim: token feature width shared by every sub-module.
            text_dim: feature width of the text tokens seen by cross-attention.
            t5_dim: feature width of the incoming T5 embeddings.
            text_length: token count of the first text encoder.
            t5_length: token count of the T5 encoder.
        """
        super().__init__()

        # Embedders
        self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
        self.t5_embedder = torch.nn.Sequential(
            torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
            FP32_SiLU(),
            torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
        )
        self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
        self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
        # hidden_dim is forwarded explicitly so that non-default widths stay consistent
        # across the patch embedder, timestep embedder, blocks and final layer.
        self.patch_embedder = PatchEmbed(in_chans=in_channels, embed_dim=hidden_dim)
        self.timestep_embedder = TimestepEmbedder(hidden_size=hidden_dim)
        # Input: 6 size values x 256 features, the 1024-wide pooled T5 vector, and the style vector.
        self.extra_embedder = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
            FP32_SiLU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer blocks
        self.num_layers_down = num_layers_down
        self.num_layers_up = num_layers_up
        self.blocks = torch.nn.ModuleList(
            [
                HunyuanDiTBlock(hidden_dim=hidden_dim, condition_dim=hidden_dim, text_dim=text_dim, skip_connection=False)
                for _ in range(num_layers_down)
            ] + [
                HunyuanDiTBlock(hidden_dim=hidden_dim, condition_dim=hidden_dim, text_dim=text_dim, skip_connection=True)
                for _ in range(num_layers_up)
            ]
        )

        # Output layers
        self.final_layer = HunyuanDiTFinalLayer(final_hidden_size=hidden_dim, condition_dim=hidden_dim, out_channels=out_channels)
        self.out_channels = out_channels

    def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
        """Merge both text embeddings and replace masked-out tokens with learned padding.

        The T5 embedding is first projected by ``t5_embedder``; the two
        sequences and their masks are concatenated, and positions whose mask is
        False take the corresponding row of ``text_emb_padding``.
        """
        text_emb_mask = text_emb_mask.bool()
        text_emb_mask_t5 = text_emb_mask_t5.bool()
        text_emb_t5 = self.t5_embedder(text_emb_t5)
        text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
        text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
        text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
        return text_emb

    def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
        """Build the conditioning vector from timestep, pooled T5 text, size and style.

        Args:
            text_emb_t5: raw (un-projected) T5 embeddings, pooled by ``t5_pooler``.
            timestep: timesteps fed to ``timestep_embedder``.
            size_emb: size values; flattened, embedded with 256 features each and
                regrouped six per sample.
            dtype: dtype the size embedding is cast to.
            batch_size: number of times the style vector is repeated.
        """
        # Text embedding
        pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)

        # Timestep embedding
        timestep_emb = self.timestep_embedder(timestep)

        # Size embedding
        size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
        size_emb = size_emb.view(-1, 6 * 256)

        # Style embedding
        style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)

        # Concatenate all extra vectors
        extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
        condition_emb = timestep_emb + self.extra_embedder(extra_emb)

        return condition_emb

    def unpatchify(self, x, h, w):
        """Fold a token sequence over an ``h`` x ``w`` grid of 2x2 patches back into an image."""
        return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)

    def build_mask(self, data, is_bound):
        """Build blending weights for one tile.

        Weights ramp from ``1 / border_width`` at a tile edge up to 1 in the
        interior. An edge flagged in ``is_bound`` (top, bottom, left, right) lies
        on the border of the full grid and is not faded.

        Returns:
            A ``(1, H, W)`` tensor with the dtype and device of ``data``.
        """
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        # Distance to each of the four edges; bound edges use the maximum so they never win the min.
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
        """Run ``block`` over overlapping spatial tiles and blend the results.

        Args:
            block: the transformer block to apply.
            hidden_states: features laid out as (B, C, H, W).
            condition_emb, text_emb, freq_cis_img: forwarded to ``block`` unchanged.
            residual: optional (B, C, H, W) skip features, tiled like ``hidden_states``.
            torch_dtype: dtype of the accumulation buffers.
            data_device: device holding the accumulation buffers.
            computation_device: device on which each tile is processed.
            tile_size: side length of a tile.
            tile_stride: step between tile origins.

        Returns:
            The blended (B, C, H, W) output.
        """
        B, C, H, W = hidden_states.shape

        weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                # Skip origins whose area is already covered by the previous tile.
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                # Shift the last tile back inside the grid. The start is clamped to 0 so a
                # tile larger than the grid covers it exactly instead of producing a
                # negative slice start.
                if h_ > H: h, h_ = max(H - tile_size, 0), H
                if w_ > W: w, w_ = max(W - tile_size, 0), W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
            if residual is not None:
                residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
                residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
            else:
                residual_batch = None

            # Forward
            hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)

            # Accumulate the weighted tile and its weights, then normalize once at the end.
            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        values /= weight
        return values

    def forward(
        self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
        tiled=False, tile_size=64, tile_stride=32,
        to_cache=False,
        use_gradient_checkpointing=False,
    ):
        """Predict the model output for ``hidden_states``.

        Args:
            hidden_states: input whose last two dims are height and width.
            text_emb, text_emb_mask: embeddings and mask from the first text encoder.
            text_emb_t5, text_emb_mask_t5: embeddings and mask from the T5 encoder.
            timestep: diffusion timesteps.
            size_emb: size conditioning values (six per sample).
            freq_cis_img: rotary tables passed to every block.
            tiled: process each block tile by tile.
            tile_size, tile_stride: tiling geometry on the patch grid.
            to_cache: forwarded to the blocks in the non-tiled, non-checkpointed path only.
            use_gradient_checkpointing: checkpoint each block while training
                (non-tiled path only).

        Returns:
            The first half of the predicted channels.
        """
        # Embeddings
        text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
        condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])

        # Input
        height, width = hidden_states.shape[-2], hidden_states.shape[-1]
        hidden_states = self.patch_embedder(hidden_states)

        # Blocks
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if tiled:
            hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
            residuals = []
            for block_id, block in enumerate(self.blocks):
                # Skip-connected blocks consume the stored residuals in reverse order.
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                hidden_states = self.tiled_block_forward(
                    block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                    torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
                    tile_size=tile_size, tile_stride=tile_stride
                )
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
        else:
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                if self.training and use_gradient_checkpointing:
                    # Import explicitly: ``import torch`` alone does not guarantee that the
                    # ``torch.utils.checkpoint`` submodule has been loaded.
                    from torch.utils.checkpoint import checkpoint
                    hidden_states = checkpoint(
                        create_custom_forward(block),
                        hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)

        # Output
        hidden_states = self.final_layer(hidden_states, condition_emb)
        hidden_states = self.unpatchify(hidden_states, height//2, width//2)
        # Keep the first half of the channels; the second half is discarded.
        hidden_states, _ = hidden_states.chunk(2, dim=1)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints onto this module's keys."""
        return HunyuanDiTStateDictConverter()
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class HunyuanDiTStateDictConverter():
    """Translate external Hunyuan-DiT checkpoint keys to this package's layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename keys and split fused projection weights.

        Every key goes through the substring renames below, in order. Parameters
        whose renamed key contains ``.kv_proj.`` are split in two along dim 0
        (k, v); those containing ``.Wqkv.`` are split in three (q, k, v); the
        style embedding is squeezed. Everything else is copied as is.

        Args:
            state_dict: mapping from parameter name to tensor.

        Returns:
            A new dict with converted names and tensors.
        """
        # Order matters: later rules act on the output of earlier ones
        # (e.g. "pooler." -> "t5_pooler." is then corrected for its q projection).
        rename_rules = (
            (".default_modulation.", ".modulation."),
            (".mlp.fc1.", ".mlp.0."),
            (".mlp.fc2.", ".mlp.2."),
            (".attn1.q_norm.", ".rota1.q_norm."),
            (".attn2.q_norm.", ".rota2.q_norm."),
            (".attn1.k_norm.", ".rota1.k_norm."),
            (".attn2.k_norm.", ".rota2.k_norm."),
            (".q_proj.", ".to_q."),
            (".out_proj.", ".to_out."),
            ("text_embedding_padding", "text_emb_padding"),
            ("mlp_t5.0.", "t5_embedder.0."),
            ("mlp_t5.2.", "t5_embedder.2."),
            ("pooler.", "t5_pooler."),
            ("x_embedder.", "patch_embedder."),
            ("t_embedder.", "timestep_embedder."),
            ("t5_pooler.to_q.", "t5_pooler.q_proj."),
            ("style_embedder.weight", "style_embedder"),
        )

        converted = {}
        for source_name, param in state_dict.items():
            target_name = source_name
            for old, new in rename_rules:
                target_name = target_name.replace(old, new)

            if ".kv_proj." in target_name:
                # Fused key/value projection: first half is k, second half is v.
                half = param.shape[0] // 2
                converted[target_name.replace(".kv_proj.", ".to_k.")] = param[:half]
                converted[target_name.replace(".kv_proj.", ".to_v.")] = param[half:]
                continue

            if ".Wqkv." in target_name:
                # Fused q/k/v projection: three consecutive slices along dim 0.
                third = param.shape[0] // 3
                converted[target_name.replace(".Wqkv.", ".to_q.")] = param[:third]
                converted[target_name.replace(".Wqkv.", ".to_k.")] = param[third:third * 2]
                converted[target_name.replace(".Wqkv.", ".to_v.")] = param[third * 2:]
                continue

            if "style_embedder" in target_name:
                converted[target_name] = param.squeeze()
            else:
                converted[target_name] = param
        return converted

    def from_civitai(self, state_dict):
        """Convert a Civitai-style checkpoint; it shares the diffusers key layout."""
        return self.from_diffusers(state_dict)
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class HunyuanDiTCLIPTextEncoder(BertModel):
    """BERT-based text encoder with a fixed configuration and ``clip_skip`` support."""

    def __init__(self):
        """Instantiate the BERT backbone without a pooling layer and switch to eval mode."""
        config_kwargs = {
            "_name_or_path": "",
            "architectures": ["BertModel"],
            "attention_probs_dropout_prob": 0.1,
            "bos_token_id": 0,
            "classifier_dropout": None,
            "directionality": "bidi",
            "eos_token_id": 2,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "layer_norm_eps": 1e-12,
            "max_position_embeddings": 512,
            "model_type": "bert",
            "num_attention_heads": 16,
            "num_hidden_layers": 24,
            "output_past": True,
            "pad_token_id": 0,
            "pooler_fc_size": 768,
            "pooler_num_attention_heads": 12,
            "pooler_num_fc_layers": 3,
            "pooler_size_per_head": 128,
            "pooler_type": "first_token_transform",
            "position_embedding_type": "absolute",
            "torch_dtype": "float32",
            "transformers_version": "4.37.2",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 47020,
        }
        super().__init__(BertConfig(**config_kwargs), add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        """Encode token ids and return the hidden states selected by ``clip_skip``.

        Args:
            input_ids: 2-D tensor of token ids (it is unpacked as batch, length).
            attention_mask: mask for ``input_ids``; an all-ones mask is used when None.
            clip_skip: ``hidden_states[-clip_skip]`` is returned. For values above 1
                the result is re-standardized to the scalar mean/std of the last
                hidden state.

        Returns:
            The selected (and possibly re-scaled) hidden states.
        """
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        if attention_mask is None:
            # No past key/values are used, so the mask spans exactly the input length.
            attention_mask = torch.ones((batch_size, seq_length + 0), device=input_ids.device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        hidden_states = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        ).hidden_states

        prompt_emb = hidden_states[-clip_skip]
        if clip_skip <= 1:
            return prompt_emb

        # Match the statistics of the final layer so downstream scales stay unchanged.
        final_layer = hidden_states[-1]
        mean, std = final_layer.mean(), final_layer.std()
        return (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean

    @staticmethod
    def state_dict_converter():
        """Return the converter for external checkpoints of this encoder."""
        return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class HunyuanDiTT5TextEncoder(T5EncoderModel):
    """T5 encoder with a fixed configuration and ``clip_skip`` support."""

    def __init__(self):
        """Instantiate the encoder from the hard-coded configuration and switch to eval mode."""
        config_kwargs = {
            "_name_or_path": "../HunyuanDiT/t2i/mt5",
            "architectures": ["MT5ForConditionalGeneration"],
            "classifier_dropout": 0.0,
            "d_ff": 5120,
            "d_kv": 64,
            "d_model": 2048,
            "decoder_start_token_id": 0,
            "dense_act_fn": "gelu_new",
            "dropout_rate": 0.1,
            "eos_token_id": 1,
            "feed_forward_proj": "gated-gelu",
            "initializer_factor": 1.0,
            "is_encoder_decoder": True,
            "is_gated_act": True,
            "layer_norm_epsilon": 1e-06,
            "model_type": "t5",
            "num_decoder_layers": 24,
            "num_heads": 32,
            "num_layers": 24,
            "output_past": True,
            "pad_token_id": 0,
            "relative_attention_max_distance": 128,
            "relative_attention_num_buckets": 32,
            "tie_word_embeddings": False,
            "tokenizer_class": "T5Tokenizer",
            "transformers_version": "4.37.2",
            "use_cache": True,
            "vocab_size": 250112,
        }
        super().__init__(T5Config(**config_kwargs))
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        """Encode token ids and return the hidden states selected by ``clip_skip``.

        Args:
            input_ids: token ids passed to the parent encoder.
            attention_mask: mask passed to the parent encoder.
            clip_skip: ``hidden_states[-clip_skip]`` is returned. For values above 1
                the result is re-standardized to the scalar mean/std of the last
                hidden state.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states

        prompt_emb = hidden_states[-clip_skip]
        if clip_skip <= 1:
            return prompt_emb

        # Match the statistics of the final layer so downstream scales stay unchanged.
        final_layer = hidden_states[-1]
        mean, std = final_layer.mean(), final_layer.std()
        return (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean

    @staticmethod
    def state_dict_converter():
        """Return the converter for external checkpoints of this encoder."""
        return HunyuanDiTT5TextEncoderStateDictConverter()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class HunyuanDiTCLIPTextEncoderStateDictConverter():
    """Adapt external BERT checkpoints to ``HunyuanDiTCLIPTextEncoder``."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Keep only the ``bert.``-prefixed entries and strip that prefix.

        Args:
            state_dict: mapping from parameter name to value.

        Returns:
            A new dict; entries without the prefix are dropped.
        """
        prefix = "bert."
        converted = {}
        for name, param in state_dict.items():
            if not name.startswith(prefix):
                continue
            converted[name[len(prefix):]] = param
        return converted

    def from_civitai(self, state_dict):
        """Convert a Civitai-style checkpoint; it shares the diffusers key layout."""
        return self.from_diffusers(state_dict)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class HunyuanDiTT5TextEncoderStateDictConverter():
    """Adapt external T5 checkpoints to ``HunyuanDiTT5TextEncoder``."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Keep the ``encoder.`` entries plus the shared token embedding.

        Args:
            state_dict: mapping from parameter name to value; it must contain
                ``"shared.weight"``.

        Returns:
            A new dict with the encoder entries followed by ``"shared.weight"``.

        Raises:
            KeyError: if ``"shared.weight"`` is missing from ``state_dict``.
        """
        converted = {}
        for name, param in state_dict.items():
            if name.startswith("encoder."):
                converted[name] = param
        # The embedding table lives outside the "encoder." namespace.
        converted["shared.weight"] = state_dict["shared.weight"]
        return converted

    def from_civitai(self, state_dict):
        """Convert a Civitai-style checkpoint; it shares the diffusers key layout."""
        return self.from_diffusers(state_dict)
|