diffsynth-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,307 @@
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Dict
5
+
6
+ from diffsynth_engine.models.components.clip import CLIPEncoderLayer
7
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
8
+ from diffsynth_engine.models.utils import no_init_weights
9
+ from diffsynth_engine.utils.constants import SDXL_TEXT_ENCODER_CONFIG_FILE
10
+ from diffsynth_engine.utils import logging
11
+
12
# Module-level logger for this file.
logger = logging.get_logger(__name__)

# Key-renaming tables for converting diffusers / civitai checkpoints to
# diffsynth parameter names are stored in a JSON config file, loaded once
# at import time into the module-global `config`.
with open(SDXL_TEXT_ENCODER_CONFIG_FILE, "r") as f:
    config = json.load(f)
16
+
17
+
18
class SDXLTextEncoderStateDictConverter(StateDictConverter):
    """Rename checkpoint keys of SDXL's first text encoder to this package's
    parameter names.

    Two external layouts are recognized: HuggingFace diffusers checkpoints and
    civitai (single-file) checkpoints. Anything else is assumed to already be
    in diffsynth format and is passed through unchanged.
    """

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map diffusers-format keys using the ``te1`` rename tables from the
        module-level JSON config. Keys matching neither table nor the encoder
        layer prefix are dropped."""
        rename_dict = config["diffusers"]["te1_rename_dict"]
        attn_rename_dict = config["diffusers"]["te1_attn_rename_dict"]
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    # Checkpoint stores (seq, dim); the model keeps a leading
                    # batch axis: (1, seq, dim).
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                # Key shape: text_model.encoder.layers.<id>.<submodule...>.<weight|bias>
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                if layer_id == "11":
                    # we don't need the last layer
                    continue
                state_dict_[name_] = param
        return state_dict_

    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map civitai-format keys (under ``conditioner.embedders.0``) using the
        ``te1`` rename table; keys of other submodules are ignored."""
        rename_dict = config["civitai"]["te1_rename_dict"]
        state_dict_ = {}
        for name, param in state_dict.items():
            if not name.startswith("conditioner.embedders.0"):
                # Belongs to another submodule (TE2 / UNet / VAE); handled by
                # that submodule's own converter.
                continue
            if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
                # Add the leading batch axis expected by the model and strip
                # the ".weight" suffix so the rename table matches.
                param = param.reshape((1, param.shape[0], param.shape[1]))
                name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding"
                suffix = ""
            else:
                # Split off the trailing ".weight"/".bias" so the base name can
                # be looked up in the rename table.
                name, suffix = split_suffix(name)
            if name in rename_dict:
                name_ = rename_dict[name] + suffix
                state_dict_[name_] = param
        return state_dict_

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detect the checkpoint layout by a marker key and convert accordingly;
        unknown layouts are assumed to be diffsynth-native and returned as-is."""
        if "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight" in state_dict:
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        elif "text_model.final_layer_norm.weight" in state_dict:
            state_dict = self._from_diffusers(state_dict)
            logger.info("use diffusers format state dict")
        else:
            logger.info("use diffsynth format state dict")
        return state_dict
67
+
68
+
69
class SDXLTextEncoder2StateDictConverter(StateDictConverter):
    """Rename checkpoint keys of SDXL's second text encoder to this package's
    parameter names.

    Handles HuggingFace diffusers checkpoints and civitai (single-file)
    checkpoints; other layouts are assumed to already be diffsynth format.
    """

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map diffusers-format keys using the ``te2`` rename tables from the
        module-level JSON config. Unlike TE1, all encoder layers are kept."""
        rename_dict = config["diffusers"]["te2_rename_dict"]
        attn_rename_dict = config["diffusers"]["te2_attn_rename_dict"]
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    # Checkpoint stores (seq, dim); model keeps (1, seq, dim).
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                # Key shape: text_model.encoder.layers.<id>.<submodule...>.<weight|bias>
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map civitai-format keys (under ``conditioner.embedders.1``) using the
        ``te2`` rename table.

        Special cases: fused attention ``in_proj`` weights/biases are renamed to
        a common base key and later split three ways (q/k/v) when the rename
        table maps that key to a list; ``text_projection`` is transposed from
        (in, out) to nn.Linear's (out, in) layout; the positional embedding
        gains a leading batch axis.
        """
        rename_dict = config["civitai"]["te2_rename_dict"]
        state_dict_ = {}
        for name, param in state_dict.items():
            if not name.startswith("conditioner.embedders.1"):
                # Belongs to another submodule; handled elsewhere.
                continue
            if name.endswith(".in_proj_weight"):
                name = name.replace(".in_proj_weight", ".in_proj")
                suffix = ".weight"
            elif name.endswith(".in_proj_bias"):
                name = name.replace(".in_proj_bias", ".in_proj")
                suffix = ".bias"
            elif name == "conditioner.embedders.1.model.text_projection":
                name = "conditioner.embedders.1.model.text_projection"
                suffix = ".weight"
                # Stored as a plain matrix applied as x @ W; nn.Linear expects
                # the transpose.
                param = param.T
            elif name == "conditioner.embedders.1.model.positional_embedding":
                name = "conditioner.embedders.1.model.positional_embedding"
                suffix = ""
                # Add the leading batch axis expected by the model.
                param = param.reshape((1, param.shape[0], param.shape[1]))
            else:
                name, suffix = split_suffix(name)

            if name in rename_dict:
                if isinstance(rename_dict[name], str):
                    name_ = rename_dict[name] + suffix
                    state_dict_[name_] = param
                else:
                    # Rename target is a list: split the fused q/k/v projection
                    # into three equal chunks along dim 0, one per target name.
                    length = param.shape[0] // 3
                    for i, rename in enumerate(rename_dict[name]):
                        name_ = rename + suffix
                        state_dict_[name_] = param[i * length : i * length + length]
        return state_dict_

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detect the checkpoint layout by a marker key and convert accordingly;
        unknown layouts are assumed to be diffsynth-native and returned as-is."""
        if "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight" in state_dict:
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        elif "text_model.final_layer_norm.weight" in state_dict:
            state_dict = self._from_diffusers(state_dict)
            logger.info("use diffusers format state dict")
        else:
            logger.info("use diffsynth format state dict")
        return state_dict
132
+
133
+
134
class SDXLTextEncoder(PreTrainedModel):
    """First text encoder of SDXL.

    A CLIP-style transformer stack deliberately built with only 11 of the
    usual 12 layers (the final layer is never used by SDXL) and, unlike the
    Stable Diffusion 1.x encoder, no final_layer_norm.
    """

    converter = SDXLTextEncoderStateDictConverter()

    def __init__(
        self,
        embed_dim=768,
        vocab_size=49408,
        max_position_embeddings=77,
        num_encoder_layers=11,
        encoder_intermediate_size=3072,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()

        # Token embedding table.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device, dtype=dtype)

        # Learned positional embedding, stored with a leading batch axis so it
        # broadcasts directly against (batch, seq, dim) activations.
        self.position_embeds = nn.Parameter(
            torch.zeros(1, max_position_embeddings, embed_dim, device=device, dtype=dtype)
        )

        # Transformer encoder stack (11 layers by default; the 12th CLIP layer
        # is intentionally omitted).
        layers = [
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, device=device, dtype=dtype)
            for _ in range(num_encoder_layers)
        ]
        self.encoders = nn.ModuleList(layers)

        # Causal attention mask, built once for the maximum sequence length.
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # NOTE: no final_layer_norm here — see class docstring.

    def attention_mask(self, length):
        """Return a (length, length) causal mask: -inf strictly above the
        diagonal, 0 on and below it."""
        return torch.full((length, length), float("-inf")).triu(1)

    def forward(self, input_ids, clip_skip=2):
        # The last CLIP layer was never loaded, so one layer is effectively
        # already skipped; compensate by lowering clip_skip (never below 1).
        skip = max(clip_skip - 1, 1)
        hidden = self.token_embedding(input_ids) + self.position_embeds
        mask = self.attn_mask.to(device=hidden.device, dtype=hidden.dtype)
        stop_index = len(self.encoders) - skip  # last layer to execute
        for index, layer in enumerate(self.encoders):
            hidden = layer(hidden, attn_mask=mask)
            if index == stop_index:
                break
        return hidden

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        embed_dim: int = 768,
        vocab_size: int = 49408,
        max_position_embeddings: int = 77,
        num_encoder_layers: int = 11,
        encoder_intermediate_size: int = 3072,
    ):
        """Instantiate the module without weight initialization, then load the
        given (already converted) state dict into it."""
        init_kwargs = dict(
            device=device,
            dtype=dtype,
            embed_dim=embed_dim,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            num_encoder_layers=num_encoder_layers,
            encoder_intermediate_size=encoder_intermediate_size,
        )
        with no_init_weights():
            model = torch.nn.utils.skip_init(cls, **init_kwargs)
        model.load_state_dict(state_dict)
        return model
214
+
215
+
216
class SDXLTextEncoder2(PreTrainedModel):
    """Second text encoder of SDXL.

    A 32-layer CLIP-style transformer (20 heads of width 64, plain GELU)
    with a final layer norm and a bias-free pooled-text projection. Returns
    both the clip-skipped hidden states and the projected pooled embedding.
    """

    converter = SDXLTextEncoder2StateDictConverter()

    def __init__(
        self,
        embed_dim=1280,
        vocab_size=49408,
        max_position_embeddings=77,
        num_encoder_layers=32,
        encoder_intermediate_size=5120,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()

        # Token embedding table.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device, dtype=dtype)

        # Learned positional embedding with a leading batch axis so it
        # broadcasts against (batch, seq, dim) activations.
        self.position_embeds = nn.Parameter(
            torch.zeros(1, max_position_embeddings, embed_dim, device=device, dtype=dtype)
        )

        # Transformer encoder stack: 20 heads x 64 dims, plain (not quick) GELU.
        layers = [
            CLIPEncoderLayer(
                embed_dim,
                encoder_intermediate_size,
                num_heads=20,
                head_dim=64,
                use_quick_gelu=False,
                device=device,
                dtype=dtype,
            )
            for _ in range(num_encoder_layers)
        ]
        self.encoders = nn.ModuleList(layers)

        # Causal attention mask, built once for the maximum sequence length.
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # Final norm and pooled-text projection (absent in the first encoder).
        self.final_layer_norm = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, device=device, dtype=dtype)

    def attention_mask(self, length):
        """Return a (length, length) causal mask: -inf strictly above the
        diagonal, 0 on and below it."""
        return torch.full((length, length), float("-inf")).triu(1)

    def forward(self, input_ids, clip_skip=2):
        skip = max(clip_skip, 1)
        hidden = self.token_embedding(input_ids) + self.position_embeds
        mask = self.attn_mask.to(device=hidden.device, dtype=hidden.dtype)
        capture_index = len(self.encoders) - skip  # layer whose output is returned
        for index, layer in enumerate(self.encoders):
            hidden = layer(hidden, attn_mask=mask)
            if index == capture_index:
                # Capture the clip-skipped activations, but keep running the
                # remaining layers: the pooled output uses the full stack.
                hidden_states = hidden
        normed = self.final_layer_norm(hidden)
        # Pool at the position of the largest token id (the end-of-text token
        # in CLIP vocabularies — TODO confirm against the tokenizer), then
        # project to the pooled embedding space.
        batch_index = torch.arange(normed.shape[0])
        eot_index = input_ids.to(dtype=torch.int).argmax(dim=-1)
        pooled_embeds = self.text_projection(normed[batch_index, eot_index])
        return hidden_states, pooled_embeds

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        embed_dim: int = 1280,
        vocab_size: int = 49408,
        max_position_embeddings: int = 77,
        num_encoder_layers: int = 32,
        encoder_intermediate_size: int = 5120,
    ):
        """Instantiate the module without weight initialization, then load the
        given (already converted) state dict into it."""
        init_kwargs = dict(
            device=device,
            dtype=dtype,
            embed_dim=embed_dim,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            num_encoder_layers=num_encoder_layers,
            encoder_intermediate_size=encoder_intermediate_size,
        )
        with no_init_weights():
            model = torch.nn.utils.skip_init(cls, **init_kwargs)
        model.load_state_dict(state_dict)
        return model
@@ -0,0 +1,306 @@
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Dict
5
+
6
+ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
7
+ from diffsynth_engine.models.basic.unet_helper import (
8
+ ResnetBlock,
9
+ AttentionBlock,
10
+ PushBlock,
11
+ DownSampler,
12
+ PopBlock,
13
+ UpSampler,
14
+ )
15
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
16
+ from diffsynth_engine.models.utils import no_init_weights
17
+ from diffsynth_engine.utils.constants import SDXL_UNET_CONFIG_FILE
18
+ from diffsynth_engine.utils import logging
19
+
20
# Module-level logger for this file.
logger = logging.get_logger(__name__)

# Key-renaming tables for converting civitai-format UNet checkpoints are
# stored in a JSON config file, loaded once at import time into the
# module-global `config`.
with open(SDXL_UNET_CONFIG_FILE, "r") as f:
    config = json.load(f)
24
+
25
+
26
class SDXLUNetStateDictConverter(StateDictConverter):
    """Rename SDXL UNet checkpoint keys to this package's flat ``blocks.<i>``
    parameter names.

    Supports HuggingFace diffusers checkpoints and civitai (single-file)
    checkpoints; other layouts are assumed to already be diffsynth format.
    """

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map diffusers-format keys onto the flat block list.

        The mapping is positional: ``block_types`` mirrors the exact block
        sequence built in ``SDXLUNet.__init__``, and each diffusers sub-block
        is matched to the next unclaimed slot of its type.

        Returns the converted dict, or — when the checkpoint contains an
        ``encoder_hid_proj`` (i.e. it is a Kolors UNet) — a tuple of
        ``(converted_dict, {"is_kolors": True})`` so the caller can construct
        the model with the Kolors-specific layers.
        """
        # architecture: must stay in sync with SDXLUNet.__init__
        block_types = [
            "ResnetBlock",
            "PushBlock",
            "ResnetBlock",
            "PushBlock",
            "DownSampler",
            "PushBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PushBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PushBlock",
            "DownSampler",
            "PushBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PushBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PushBlock",
            "ResnetBlock",
            "AttentionBlock",
            "ResnetBlock",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "UpSampler",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "PopBlock",
            "ResnetBlock",
            "AttentionBlock",
            "UpSampler",
            "PopBlock",
            "ResnetBlock",
            "PopBlock",
            "ResnetBlock",
            "PopBlock",
            "ResnetBlock",
        ]

        # rename each parameter (sorted so positional matching is deterministic)
        name_list = sorted(state_dict)
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
                # Same name on both sides — keep as-is.
                pass
            elif names[0] in ["encoder_hid_proj"]:
                # Kolors-only projection of the text embedding.
                names[0] = "text_intermediate_proj"
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                else:
                    names[0] = "time_embedding.timestep_embedder"
                # diffusers uses named linears; ours is an nn.Sequential.
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    # mid_block has no block index; insert one so the key has
                    # the same shape as down/up block keys.
                    names.insert(1, "0")
                block_type = {
                    "resnets": "ResnetBlock",
                    "attentions": "AttentionBlock",
                    "downsamplers": "DownSampler",
                    "upsamplers": "UpSampler",
                }[names[2]]
                # Advance this type's cursor when a new diffusers sub-block
                # starts, then skip forward to the next slot of that type.
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    # diffusers splits the feed-forward as ff.net.0 (activation)
                    # and ff.net.2 (linear); ours are act_fn and ff.
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index : ff_index + 3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index + 3 :]
                if "to_out" in names:
                    # Drop the extra ".0" index of diffusers' to_out Sequential.
                    names.pop(names.index("to_out") + 1)
            else:
                raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if ".proj_in." in name or ".proj_out." in name:
                # diffusers stores these as 1x1 convs; we use linear layers.
                param = param.squeeze()
            state_dict_[rename_dict[name]] = param
        if "text_intermediate_proj.weight" in state_dict_:
            # Kolors checkpoint detected — signal extra constructor kwargs.
            return state_dict_, {"is_kolors": True}
        else:
            return state_dict_

    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map civitai-format keys via the rename table from the config JSON."""
        rename_dict = config["civitai"]["rename_dict"]
        state_dict_ = {}
        for name, param in state_dict.items():
            name, suffix = split_suffix(name)
            if name in rename_dict:
                if ".proj_in." in name or ".proj_out." in name:
                    # Stored as 1x1 convs; we use linear layers.
                    param = param.squeeze()
                name_ = rename_dict[name] + suffix
                state_dict_[name_] = param
        return state_dict_

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detect the checkpoint layout by a marker key and convert accordingly;
        unknown layouts are assumed to be diffsynth-native and returned as-is."""
        if "model.diffusion_model.input_blocks.0.0.weight" in state_dict:
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        elif "down_blocks.0.resnets.0.conv1.weight" in state_dict:
            state_dict = self._from_diffusers(state_dict)
            logger.info("use diffusers format state dict")
        else:
            # Fixed typo: was "user diffsynth ...", now consistent with the
            # other converters' log messages.
            logger.info("use diffsynth format state dict")
        return state_dict
159
+
160
+
161
class SDXLUNet(PreTrainedModel):
    """SDXL denoising UNet, expressed as one flat list of blocks.

    Instead of nested down/mid/up modules, the network is a linear sequence of
    Resnet/Attention/sampler blocks interleaved with PushBlock/PopBlock markers
    that push activations onto / pop them from a shared residual stack. The
    block order must stay in sync with ``SDXLUNetStateDictConverter``, which
    renames checkpoint keys positionally.

    ``is_kolors`` enables the Kolors variant: a wider additive conditioning
    input (5632 vs 2816) and a 4096->2048 projection of the text embedding.
    """

    converter = SDXLUNetStateDictConverter()

    def __init__(
        self,
        is_kolors: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_gradient_checkpointing = use_gradient_checkpointing
        # Sinusoidal timestep embedding projected 320 -> 1280.
        self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
        # MLP over the extra conditioning vector `y` (size/crop/pooled-text);
        # Kolors uses a wider input.
        self.add_time_embedding = nn.Sequential(
            nn.Linear(5632 if is_kolors else 2816, 1280, device=device, dtype=dtype),
            nn.SiLU(),
            nn.Linear(1280, 1280, device=device, dtype=dtype),
        )
        # 4 latent channels in, 320 feature channels.
        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
        # Kolors-only projection of the text embedding (4096 -> 2048).
        self.text_intermediate_proj = nn.Linear(4096, 2048, device=device, dtype=dtype) if is_kolors else None

        # Flat block sequence; comments mark the diffusers blocks it mirrors.
        # Channel plan: 320 -> 640 -> 1280 down, mid at 1280, back up to 320.
        self.blocks = nn.ModuleList(
            [
                # DownBlock2D
                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
                PushBlock(),
                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
                PushBlock(),
                DownSampler(320, device=device, dtype=dtype),
                PushBlock(),
                # CrossAttnDownBlock2D
                ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
                PushBlock(),
                ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
                PushBlock(),
                DownSampler(640, device=device, dtype=dtype),
                PushBlock(),
                # CrossAttnDownBlock2D
                ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                PushBlock(),
                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                PushBlock(),
                # UNetMidBlock2DCrossAttn
                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
                # CrossAttnUpBlock2D
                PopBlock(),
                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(1920, 1280, 1280, device=device, dtype=dtype),
                AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
                UpSampler(1280, device=device, dtype=dtype),
                # CrossAttnUpBlock2D
                PopBlock(),
                ResnetBlock(1920, 640, 1280, device=device, dtype=dtype),
                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(1280, 640, 1280, device=device, dtype=dtype),
                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(960, 640, 1280, device=device, dtype=dtype),
                AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
                UpSampler(640, device=device, dtype=dtype),
                # UpBlock2D
                PopBlock(),
                ResnetBlock(960, 320, 1280, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
                PopBlock(),
                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
            ]
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5, device=device, dtype=dtype)
        self.conv_act = nn.SiLU()
        # Back to 4 latent channels (noise prediction).
        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1, device=device, dtype=dtype)

        self.is_kolors = is_kolors

    def forward(self, x, timestep, context, y, **kwargs):
        """Predict noise for latents ``x`` at ``timestep``.

        x: 4-channel latent tensor (conv_in expects 4 input channels).
        timestep: diffusion timestep(s) for the time embedding.
        context: text-encoder embedding for cross-attention
            (projected 4096 -> 2048 first when is_kolors).
        y: additive conditioning vector fed to add_time_embedding.
        kwargs: ignored; accepted for pipeline-interface compatibility.
        """
        # 1. time embedding
        t_emb = self.time_embedding(timestep, dtype=x.dtype)
        ## add embedding
        add_embeds = self.add_time_embedding(y)

        time_emb = t_emb + add_embeds

        # 2. pre-process
        hidden_states = self.conv_in(x)
        text_emb = context if self.text_intermediate_proj is None else self.text_intermediate_proj(context)
        res_stack = [hidden_states]

        # 3. blocks
        # Wrapper so torch.utils.checkpoint can re-invoke the block.
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for i, block in enumerate(self.blocks):
            # Checkpoint only compute-heavy blocks; Push/Pop just move tensors
            # on the residual stack and are not worth recomputing.
            if (
                self.training
                and self.use_gradient_checkpointing
                and not (isinstance(block, PushBlock) or isinstance(block, PopBlock))
            ):
                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    time_emb,
                    text_emb,
                    res_stack,
                    use_reentrant=False,
                )
            else:
                # Every block shares the same 4-tuple interface.
                hidden_states, time_emb, text_emb, res_stack = block(
                    hidden_states,
                    time_emb,
                    text_emb,
                    res_stack,
                )

        # 4. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    @classmethod
    def from_state_dict(
        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, is_kolors: bool = False
    ):
        """Instantiate without weight initialization and load *state_dict*.

        Uses assign=True so loaded tensors replace the uninitialized ones
        directly, then moves the model to the target device/dtype.
        """
        with no_init_weights():
            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, is_kolors=is_kolors)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from typing import Dict
3
+
4
+ from diffsynth_engine.models.components.vae import VAEDecoder, VAEEncoder
5
+ from diffsynth_engine.models.utils import no_init_weights
6
+
7
+
8
class SDXLVAEEncoder(VAEEncoder):
    """SDXL image-to-latent VAE encoder: 4 latent channels, scaling factor
    0.13025, no shift, with a quant conv."""

    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
        super().__init__(
            latent_channels=4,
            scaling_factor=0.13025,
            shift_factor=0,
            use_quant_conv=True,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
        """Instantiate without weight initialization and load *state_dict*."""
        with no_init_weights():
            encoder = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
        encoder.load_state_dict(state_dict)
        return encoder
20
+
21
+
22
class SDXLVAEDecoder(VAEDecoder):
    """SDXL latent-to-image VAE decoder: 4 latent channels, scaling factor
    0.13025, no shift, with a post-quant conv."""

    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
        vae_config = dict(
            latent_channels=4,
            scaling_factor=0.13025,
            shift_factor=0,
            use_post_quant_conv=True,
        )
        super().__init__(device=device, dtype=dtype, **vae_config)

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
        """Instantiate without weight initialization and load *state_dict*."""
        with no_init_weights():
            decoder = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
        decoder.load_state_dict(state_dict)
        return decoder