diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,451 @@
1
+ from .attention import Attention
2
+ from einops import repeat, rearrange
3
+ import math
4
+ import torch
5
+
6
+
7
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
    """Q/K LayerNorm followed by rotary position embedding (RoPE), with a simple K/V cache.

    A call with ``to_cache=True`` appends its (normalized, possibly rotated)
    keys and its values to ``k_cache`` / ``v_cache``. The next call with
    ``to_cache=False`` concatenates its own keys/values with the cached ones
    along dim 2 and then clears the cache.
    """

    def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
        super().__init__()
        self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
        # When False, RoPE is applied to the queries only (used for cross-attention).
        self.rotary_emb_on_k = rotary_emb_on_k
        self.k_cache, self.v_cache = [], []

    def reshape_for_broadcast(self, freqs_cis, x):
        """Reshape the ``(cos, sin)`` pair so that it broadcasts against ``x``.

        Every dimension of ``x`` except the last two is replaced by 1, so
        ``freqs_cis[0]`` / ``freqs_cis[1]`` must hold exactly
        ``x.shape[-2] * x.shape[-1]`` elements each.
        """
        ndim = x.ndim
        shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    def rotate_half(self, x):
        """Map every adjacent pair ``(a, b)`` of the last dim to ``(-b, a)``; result is float32."""
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # flatten(-2) merges the trailing (pairs, 2) axes back together for
        # inputs of any rank; it matches the former flatten(3) for 4-D inputs.
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    def apply_rotary_emb(self, xq, xk, freqs_cis):
        """Rotate ``xq`` (and ``xk`` when not None) using ``freqs_cis = (cos, sin)``.

        The math runs in float32 and each result is cast back to its input's dtype.
        Returns ``(xq_out, xk_out)``; ``xk_out`` is None when ``xk`` is None.
        """
        xk_out = None
        cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
        return xq_out, xk_out

    def forward(self, q, k, v, freqs_cis_img, to_cache=False):
        """Normalize q/k, apply RoPE, and update or consume the K/V cache.

        NOTE(review): q/k/v look like (batch, heads, seq, head_dim) given the
        dim-2 concatenation below — confirm against ``Attention``.
        Returns the processed ``(q, k, v)``.
        """
        # norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE
        if self.rotary_emb_on_k:
            q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
        else:
            q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)

        if to_cache:
            self.k_cache.append(k)
            self.v_cache.append(v)
        elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
            # Attend over the current tokens plus everything cached so far,
            # then reset the cache.
            k = torch.concat([k] + self.k_cache, dim=2)
            v = torch.concat([v] + self.v_cache, dim=2)
            self.k_cache, self.v_cache = [], []
        return q, k, v
53
+
54
+
55
class FP32_Layernorm(torch.nn.LayerNorm):
    """LayerNorm that always computes in float32 and casts the result back to the input dtype."""

    def forward(self, inputs):
        origin_dtype = inputs.dtype
        # weight/bias are None when the layer is built with elementwise_affine=False
        # (or bias=False); calling .float() on them would raise AttributeError.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        return torch.nn.functional.layer_norm(
            inputs.float(), self.normalized_shape, weight, bias, self.eps
        ).to(origin_dtype)
59
+
60
+
61
class FP32_SiLU(torch.nn.SiLU):
    """SiLU evaluated in float32, with the result cast back to the caller's dtype."""

    def forward(self, inputs):
        activated = torch.nn.functional.silu(inputs.float(), inplace=False)
        return activated.to(inputs.dtype)
65
+
66
+
67
class HunyuanDiTFinalLayer(torch.nn.Module):
    """Output head: condition-modulated LayerNorm followed by a per-token linear projection.

    Each token is projected to ``patch_size * patch_size * out_channels`` features.
    """

    def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
        super().__init__()
        projected_dim = patch_size * patch_size * out_channels
        self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(final_hidden_size, projected_dim, bias=True)
        # Produces shift and scale (each final_hidden_size wide) from the condition.
        modulation_layers = [
            FP32_SiLU(),
            torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True),
        ]
        self.adaLN_modulation = torch.nn.Sequential(*modulation_layers)

    def modulate(self, x, shift, scale):
        """Return ``x * (1 + scale) + shift`` with shift/scale broadcast over dim 1."""
        expanded_scale = scale.unsqueeze(1)
        expanded_shift = shift.unsqueeze(1)
        return x * (1 + expanded_scale) + expanded_shift

    def forward(self, hidden_states, condition_emb):
        modulation = self.adaLN_modulation(condition_emb)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(hidden_states)
        return self.linear(self.modulate(normalized, shift, scale))
85
+
86
+
87
class HunyuanDiTBlock(torch.nn.Module):
    """One HunyuanDiT transformer block.

    Applies, in order: an optional long-skip fusion, self-attention with
    RoPE, cross-attention over the text embedding, and a GELU MLP. Each
    sub-layer is residual. The condition embedding only enters as an additive
    shift on the self-attention input.
    """

    def __init__(
        self,
        hidden_dim=1408,
        condition_dim=1408,
        num_heads=16,
        mlp_ratio=4.3637,
        text_dim=1024,
        skip_connection=False
    ):
        super().__init__()
        head_dim = hidden_dim // num_heads
        ffn_dim = int(hidden_dim * mlp_ratio)
        # Submodules are created in the original order so that random
        # initialisation and state_dict key order are unchanged.
        self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota1 = HunyuanDiTRotaryEmbedding(head_dim, head_dim)
        self.attn1 = Attention(hidden_dim, num_heads, head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota2 = HunyuanDiTRotaryEmbedding(head_dim, head_dim, rotary_emb_on_k=False)
        self.attn2 = Attention(hidden_dim, num_heads, head_dim, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, ffn_dim, bias=True),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(ffn_dim, hidden_dim, bias=True),
        )
        if not skip_connection:
            self.skip_norm, self.skip_linear = None, None
        else:
            # Fuses [hidden_states, residual] (2 * hidden_dim wide) back to hidden_dim.
            self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
            self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)

    def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
        """Run the block.

        ``residual`` is required when the block was built with
        ``skip_connection=True``. ``to_cache`` is forwarded to the
        self-attention K/V cache only.
        """
        x = hidden_states

        # Long skip connection (up-path blocks only).
        if self.skip_norm is not None and self.skip_linear is not None:
            fused = self.skip_norm(torch.cat([x, residual], dim=-1))
            x = self.skip_linear(fused)

        def self_attn_preprocess(q, k, v):
            return self.rota1(q, k, v, freq_cis_img, to_cache=to_cache)

        def cross_attn_preprocess(q, k, v):
            return self.rota2(q, k, v, freq_cis_img)

        # Self-attention: the condition is added as a shift on the normalized input.
        shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
        x = x + self.attn1(self.norm1(x) + shift_msa, qkv_preprocessor=self_attn_preprocess)

        # Cross-attention against the text embedding (this block uses norm3 here).
        x = x + self.attn2(self.norm3(x), text_emb, qkv_preprocessor=cross_attn_preprocess)

        # Feed-forward (uses norm2).
        x = x + self.mlp(self.norm2(x))
        return x
138
+
139
+
140
class AttentionPool(torch.nn.Module):
    """Pool a token sequence into one vector per sample with single-query attention.

    The query is the mean token. A learned positional embedding is added to
    the mean token and to every sequence token before attending.
    """

    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim=None):
        super().__init__()
        # Creation order matches the original so random initialisation is identical.
        self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Pool ``x`` of shape (N, L, C) into (N, output_dim); L must equal ``spacial_dim``."""
        tokens = x.permute(1, 0, 2)  # (N, L, C) -> (L, N, C)
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (L + 1, N, C)
        tokens = tokens + self.positional_embedding.unsqueeze(1).to(tokens.dtype)
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = torch.nn.functional.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
174
+
175
+
176
class PatchEmbed(torch.nn.Module):
    """Cut an image-like tensor into non-overlapping patches and embed each one as a token."""

    def __init__(
        self,
        patch_size=(2, 2),
        in_chans=4,
        embed_dim=1408,
        bias=True,
    ):
        super().__init__()
        # kernel_size == stride, so every output position covers exactly one patch.
        self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        """Map (B, C, H, W) to (B, num_patches, embed_dim) tokens in row-major patch order."""
        patches = self.proj(x)
        batch, channels = patches.shape[0], patches.shape[1]
        return patches.reshape(batch, channels, -1).permute(0, 2, 1)  # BCHW -> BNC
191
+
192
+
193
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    """Sinusoidal embedding of a 1-D tensor of (possibly fractional) timesteps.

    Adapted from https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

    Args:
        t: 1-D tensor with one value per row of the output.
        dim: width of the embedding.
        max_period: sets the lowest sinusoid frequency (1 / max_period).
        repeat_only: if True, skip the sinusoids and just repeat ``t`` ``dim`` times per row.

    Returns:
        Tensor of shape ``(len(t), dim)``. The first half holds cosines and the
        second half sines; one zero column is appended when ``dim`` is odd.
    """
    if repeat_only:
        return repeat(t, "b -> b d", d=dim)

    half = dim // 2
    # Frequencies decay exponentially from 1 towards 1 / max_period (size: dim // 2).
    scaled_positions = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(scaled_positions / half).to(device=t.device)
    phases = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
211
+
212
+
213
class TimestepEmbedder(torch.nn.Module):
    """Embed timesteps: sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(self, hidden_size=1408, frequency_embedding_size=256):
        super().__init__()
        layers = [
            torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = torch.nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        # Cast the sinusoids to the MLP's parameter dtype (e.g. fp16) before projecting.
        target_dtype = self.mlp[0].weight.dtype
        sinusoids = timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoids.type(target_dtype))
227
+
228
+
229
class HunyuanDiT(torch.nn.Module):
    """HunyuanDiT denoiser: a U-shaped stack of DiT blocks with long skip connections.

    The first ``num_layers_down`` blocks take no skip input; the following
    ``num_layers_up`` blocks each pop one hidden state saved from the down
    path. The conditioning vector is the timestep embedding plus an "extra"
    embedding built from pooled T5 features, six size values and a learned
    style vector. Cross-attention sees ``text_emb`` concatenated with the
    projected T5 embedding.
    """

    def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
        super().__init__()

        # Embedders (creation order unchanged so random init / state_dict order stay the same).
        # Learned values that replace masked-out text tokens (text tokens followed by T5 tokens).
        self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
        # Projects T5 features (t5_dim) to the text width used by cross-attention.
        self.t5_embedder = torch.nn.Sequential(
            torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
            FP32_SiLU(),
            torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
        )
        # The pooled T5 width is fixed at 1024 and must match the 1024 term in
        # extra_embedder's input size below.
        self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
        self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
        # hidden_dim / text_dim / out_channels are forwarded to the sub-modules
        # below; previously those fell back to their own 1408 / 1024 / 8
        # defaults, which broke any non-default configuration.
        self.patch_embedder = PatchEmbed(in_chans=in_channels, embed_dim=hidden_dim)
        self.timestep_embedder = TimestepEmbedder(hidden_size=hidden_dim)
        # Input: 6 size values x 256 sinusoid features + pooled T5 (1024) + style (hidden_dim).
        self.extra_embedder = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
            FP32_SiLU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer blocks
        self.num_layers_down = num_layers_down
        self.num_layers_up = num_layers_up
        block_kwargs = dict(hidden_dim=hidden_dim, condition_dim=hidden_dim, text_dim=text_dim)
        self.blocks = torch.nn.ModuleList(
            [HunyuanDiTBlock(skip_connection=False, **block_kwargs) for _ in range(num_layers_down)] + \
            [HunyuanDiTBlock(skip_connection=True, **block_kwargs) for _ in range(num_layers_up)]
        )

        # Output layers
        self.final_layer = HunyuanDiTFinalLayer(final_hidden_size=hidden_dim, condition_dim=hidden_dim, out_channels=out_channels)
        self.out_channels = out_channels

    def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
        """Build the cross-attention context.

        Projects the T5 features, concatenates them after ``text_emb`` along
        dim 1, and replaces positions whose mask is False with
        ``text_emb_padding``.
        """
        text_emb_mask = text_emb_mask.bool()
        text_emb_mask_t5 = text_emb_mask_t5.bool()
        text_emb_t5 = self.t5_embedder(text_emb_t5)
        text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
        text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
        text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
        return text_emb

    def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
        """Build the per-sample condition vector (width hidden_dim).

        The vector is the timestep embedding plus extra_embedder applied to
        [pooled raw T5 features, sinusoidal embedding of six size values per
        sample, style vector].
        """
        # Text embedding
        pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)

        # Timestep embedding
        timestep_emb = self.timestep_embedder(timestep)

        # Size embedding: every scalar gets a 256-wide sinusoid, regrouped as 6 per sample.
        size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
        size_emb = size_emb.view(-1, 6 * 256)

        # Style embedding
        style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)

        # Concatenate all extra vectors
        extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
        condition_emb = timestep_emb + self.extra_embedder(extra_emb)

        return condition_emb

    def unpatchify(self, x, h, w):
        """Rearrange (B, h*w, 2*2*C) patch tokens into a (B, C, 2h, 2w) image."""
        return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)

    def build_mask(self, data, is_bound):
        """Blending weights of shape (1, H, W) for one tile of ``data`` (B, C, H, W).

        Weights ramp up linearly with distance from each tile edge, clipped to
        [1, border_width] and divided by border_width. Edges flagged in
        ``is_bound`` (top, bottom, left, right) lie on the image border and
        get no ramp.
        """
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
        """Run ``block`` over overlapping spatial tiles of (B, C, H, W) ``hidden_states``.

        The tile outputs are blended with build_mask weights.
        NOTE(review): ``freq_cis_img`` is passed unchanged to every tile, so it
        presumably has to describe a tile_size x tile_size grid — confirm with
        the caller.
        """
        B, C, H, W = hidden_states.shape

        weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)

        # Split tasks: tiles start every tile_stride, tiles that would overrun
        # the border are shifted back inside it, and redundant trailing tiles
        # are skipped.
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                if h_ > H: h, h_ = H - tile_size, H
                if w_ > W: w, w_ = W - tile_size, W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
            if residual is not None:
                residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
                residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
            else:
                residual_batch = None

            # Forward
            hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        values /= weight
        return values

    def forward(
        self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
        tiled=False, tile_size=64, tile_stride=32,
        to_cache=False,
        use_gradient_checkpointing=False,
    ):
        """Predict the denoising output for latent ``hidden_states`` (B, in_channels, H, W).

        H and W are presumably even, since 2x2 patches are used. Returns the
        first half of the unpatchified output channels; the second half is
        discarded. ``to_cache`` is only honoured on the non-tiled,
        non-checkpointed path.
        """
        # Embeddings (the pooler sees the raw T5 features, not the projected ones).
        text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
        condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])

        # Input
        height, width = hidden_states.shape[-2], hidden_states.shape[-1]
        hidden_states = self.patch_embedder(hidden_states)

        # Blocks
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if tiled:
            hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                hidden_states = self.tiled_block_forward(
                    block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                    torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
                    tile_size=tile_size, tile_stride=tile_stride
                )
                # The last two down-path blocks are not used as skip sources.
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
        else:
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                if self.training and use_gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)

        # Output
        hidden_states = self.final_layer(hidden_states, condition_emb)
        hidden_states = self.unpatchify(hidden_states, height//2, width//2)
        hidden_states, _ = hidden_states.chunk(2, dim=1)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTStateDictConverter()
404
+
405
+
406
+
407
class HunyuanDiTStateDictConverter():
    """Convert HunyuanDiT checkpoint keys to this module's parameter names.

    Fused ``kv_proj`` / ``Wqkv`` tensors are split along dim 0 into separate
    to_k/to_v (and to_q) entries.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Substring renames, applied in exactly this order to every key. Order
        # matters: "pooler.q_proj." becomes "pooler.to_q.", then
        # "t5_pooler.to_q.", and is finally restored to "t5_pooler.q_proj.".
        renames = (
            (".default_modulation.", ".modulation."),
            (".mlp.fc1.", ".mlp.0."),
            (".mlp.fc2.", ".mlp.2."),
            (".attn1.q_norm.", ".rota1.q_norm."),
            (".attn2.q_norm.", ".rota2.q_norm."),
            (".attn1.k_norm.", ".rota1.k_norm."),
            (".attn2.k_norm.", ".rota2.k_norm."),
            (".q_proj.", ".to_q."),
            (".out_proj.", ".to_out."),
            ("text_embedding_padding", "text_emb_padding"),
            ("mlp_t5.0.", "t5_embedder.0."),
            ("mlp_t5.2.", "t5_embedder.2."),
            ("pooler.", "t5_pooler."),
            ("x_embedder.", "patch_embedder."),
            ("t_embedder.", "timestep_embedder."),
            ("t5_pooler.to_q.", "t5_pooler.q_proj."),
            ("style_embedder.weight", "style_embedder"),
        )
        converted = {}
        for name, param in state_dict.items():
            key = name
            for old, new in renames:
                key = key.replace(old, new)

            if ".kv_proj." in key:
                half = param.shape[0] // 2
                converted[key.replace(".kv_proj.", ".to_k.")] = param[:half]
                converted[key.replace(".kv_proj.", ".to_v.")] = param[half:]
            elif ".Wqkv." in key:
                third = param.shape[0] // 3
                converted[key.replace(".Wqkv.", ".to_q.")] = param[:third]
                converted[key.replace(".Wqkv.", ".to_k.")] = param[third:third * 2]
                converted[key.replace(".Wqkv.", ".to_v.")] = param[third * 2:]
            elif "style_embedder" in key:
                # The checkpoint stores the style vector with an extra singleton dim.
                converted[key] = param.squeeze()
            else:
                converted[key] = param
        return converted

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
@@ -0,0 +1,163 @@
1
+ from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
2
+ import torch
3
+
4
+
5
+
6
class HunyuanDiTCLIPTextEncoder(BertModel):
    """BERT text encoder (no pooling layer) used as HunyuanDiT's first text branch.

    ``forward`` returns the hidden state ``clip_skip`` layers from the end.
    When ``clip_skip > 1``, that state is renormalized to the global mean/std
    of the final layer.
    """

    def __init__(self):
        config = BertConfig(
            _name_or_path="",
            architectures=["BertModel"],
            attention_probs_dropout_prob=0.1,
            bos_token_id=0,
            classifier_dropout=None,
            directionality="bidi",
            eos_token_id=2,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=4096,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type="bert",
            num_attention_heads=16,
            num_hidden_layers=24,
            output_past=True,
            pad_token_id=0,
            pooler_fc_size=768,
            pooler_num_attention_heads=12,
            pooler_num_fc_layers=3,
            pooler_size_per_head=128,
            pooler_type="first_token_transform",
            position_embedding_type="absolute",
            torch_dtype="float32",
            transformers_version="4.37.2",
            type_vocab_size=2,
            use_cache=True,
            vocab_size=47020,
        )
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        # Attend to every token when no mask is provided.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=input_ids.device)
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        hidden_states = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        ).hidden_states

        selected = hidden_states[-clip_skip]
        if clip_skip <= 1:
            return selected
        # Rescale the earlier layer's output to the final layer's global statistics.
        final = hidden_states[-1]
        mean, std = final.mean(), final.std()
        return (selected - selected.mean()) / selected.std() * std + mean

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTCLIPTextEncoderStateDictConverter()
85
+
86
+
87
+
88
class HunyuanDiTT5TextEncoder(T5EncoderModel):
    """mT5 encoder used as HunyuanDiT's second text branch.

    ``forward`` returns the hidden state ``clip_skip`` layers from the end.
    When ``clip_skip > 1``, that state is renormalized to the global mean/std
    of the final layer.
    """

    def __init__(self):
        config = T5Config(
            _name_or_path="../HunyuanDiT/t2i/mt5",
            architectures=["MT5ForConditionalGeneration"],
            classifier_dropout=0.0,
            d_ff=5120,
            d_kv=64,
            d_model=2048,
            decoder_start_token_id=0,
            dense_act_fn="gelu_new",
            dropout_rate=0.1,
            eos_token_id=1,
            feed_forward_proj="gated-gelu",
            initializer_factor=1.0,
            is_encoder_decoder=True,
            is_gated_act=True,
            layer_norm_epsilon=1e-06,
            model_type="t5",
            num_decoder_layers=24,
            num_heads=32,
            num_layers=24,
            output_past=True,
            pad_token_id=0,
            relative_attention_max_distance=128,
            relative_attention_num_buckets=32,
            tie_word_embeddings=False,
            tokenizer_class="T5Tokenizer",
            transformers_version="4.37.2",
            use_cache=True,
            vocab_size=250112,
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        hidden_states = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states
        selected = hidden_states[-clip_skip]
        if clip_skip <= 1:
            return selected
        # Rescale the earlier layer's output to the final layer's global statistics.
        final = hidden_states[-1]
        mean, std = final.mean(), final.std()
        return (selected - selected.mean()) / selected.std() * std + mean

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTT5TextEncoderStateDictConverter()
138
+
139
+
140
+
141
class HunyuanDiTCLIPTextEncoderStateDictConverter():
    """Keep only the ``bert.``-prefixed entries of a checkpoint, with the prefix stripped."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        prefix = "bert."
        converted = {}
        for name, param in state_dict.items():
            if not name.startswith(prefix):
                continue
            converted[name[len(prefix):]] = param
        return converted

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
151
+
152
+
153
class HunyuanDiTT5TextEncoderStateDictConverter():
    """Keep the ``encoder.*`` entries plus the shared token embedding of a T5 checkpoint."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        converted = {}
        for name, param in state_dict.items():
            if name.startswith("encoder."):
                converted[name] = param
        # Raises KeyError if the checkpoint has no shared embedding.
        converted["shared.weight"] = state_dict["shared.weight"]
        return converted

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)