diffsynth-engine 0.1.0 (diffsynth_engine-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
diffsynth_engine/models/sd/sd_text_encoder.py
@@ -0,0 +1,142 @@
+import json
+import torch
+import torch.nn as nn
+from typing import Dict
+
+from diffsynth_engine.models.components.clip import CLIPEncoderLayer
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.utils.constants import SD_TEXT_ENCODER_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SD_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+    config = json.load(f)
+
+
+class SDTextEncoderStateDictConverter(StateDictConverter):
+    def __init__(self):
+        pass
+
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["rename_dict"]
+        attn_rename_dict = config["diffusers"]["attn_rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if name in rename_dict:
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                state_dict_[name_] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if name not in rename_dict:
+                continue
+            name_ = rename_dict[name]
+            if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
+                param = param.reshape((1, param.shape[0], param.shape[1]))
+            state_dict_[name_] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "text_model.encoder.layers.0.layer_norm1.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SDTextEncoder(PreTrainedModel):
+    converter = SDTextEncoderStateDictConverter()
+
+    def __init__(
+        self,
+        embed_dim=768,
+        vocab_size=49408,
+        max_position_embeddings=77,
+        num_encoder_layers=12,
+        encoder_intermediate_size=3072,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+
+        # token_embedding
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device, dtype=dtype)
+
+        # position_embeds (This is a fixed tensor)
+        self.position_embeds = nn.Parameter(
+            torch.zeros(1, max_position_embeddings, embed_dim, device=device, dtype=dtype)
+        )
+
+        # encoders
+        self.encoders = nn.ModuleList(
+            [
+                CLIPEncoderLayer(embed_dim, encoder_intermediate_size, device=device, dtype=dtype)
+                for _ in range(num_encoder_layers)
+            ]
+        )
+
+        # attn_mask
+        self.attn_mask = self.attention_mask(max_position_embeddings)
+
+        # final_layer_norm
+        self.final_layer_norm = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
+
+    def attention_mask(self, length):
+        mask = torch.empty(length, length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def forward(self, input_ids, clip_skip=1):
+        clip_skip = max(clip_skip, 1)
+        embeds = self.token_embedding(input_ids)
+        embeds += self.position_embeds.to(device=embeds.device, dtype=embeds.dtype)
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                break
+        embeds = self.final_layer_norm(embeds)
+        return embeds
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        embed_dim: int = 768,
+        vocab_size: int = 49408,
+        max_position_embeddings: int = 77,
+        num_encoder_layers: int = 12,
+        encoder_intermediate_size: int = 3072,
+    ):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(
+                cls,
+                device=device,
+                dtype=dtype,
+                embed_dim=embed_dim,
+                vocab_size=vocab_size,
+                max_position_embeddings=max_position_embeddings,
+                num_encoder_layers=num_encoder_layers,
+                encoder_intermediate_size=encoder_intermediate_size,
+            )
+        model.load_state_dict(state_dict)
+        return model
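Note: for orientation, a minimal usage sketch of the encoder above. It assumes the `diffsynth_engine.models.sd` package re-exports `SDTextEncoder` (its `__init__.py` is listed but not shown here), and the checkpoint path is hypothetical; the converter call and `clip_skip` semantics come straight from the code above (`clip_skip=2` stops one encoder layer early, and the final layer norm is still applied).

import torch
from diffsynth_engine.models.sd import SDTextEncoder  # assumed re-export

raw = torch.load("text_encoder.pth", map_location="cpu")  # hypothetical checkpoint
sd = SDTextEncoder.converter.convert(raw)                 # normalize civitai/diffusers names
encoder = SDTextEncoder.from_state_dict(sd, device="cuda:0", dtype=torch.float16)

input_ids = torch.zeros(1, 77, dtype=torch.long, device="cuda:0")  # dummy CLIP token ids
embeds = encoder(input_ids)                    # [1, 77, 768]; all 12 encoder layers
penultimate = encoder(input_ids, clip_skip=2)  # stops one layer early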
diffsynth_engine/models/sd/sd_unet.py
@@ -0,0 +1,293 @@
+import json
+import torch
+import torch.nn as nn
+from typing import Dict
+
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.utils import no_init_weights
+from diffsynth_engine.models.basic.unet_helper import (
+    ResnetBlock,
+    AttentionBlock,
+    PushBlock,
+    DownSampler,
+    PopBlock,
+    UpSampler,
+)
+from diffsynth_engine.utils.constants import SD_UNET_CONFIG_FILE
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+with open(SD_UNET_CONFIG_FILE) as f:
+    config = json.load(f)
+
+
+class SDUNetStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        # architecture
+        block_types = [
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "DownSampler",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "DownSampler",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PushBlock",
+            "DownSampler",
+            "PushBlock",
+            "ResnetBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "PushBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "UpSampler",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "UpSampler",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "UpSampler",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+            "PopBlock",
+            "ResnetBlock",
+            "AttentionBlock",
+        ]
+
+        # rename each parameter
+        name_list = sorted([name for name in state_dict])
+        rename_dict = {}
+        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+        for name in name_list:
+            names = name.split(".")
+            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
+                pass
+            elif names[0] in ["time_embedding", "add_embedding"]:
+                if names[0] == "add_embedding":
+                    names[0] = "add_time_embedding"
+                else:
+                    names[0] = "time_embedding.timestep_embedder"
+                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+                if names[0] == "mid_block":
+                    names.insert(1, "0")
+                block_type = {
+                    "resnets": "ResnetBlock",
+                    "attentions": "AttentionBlock",
+                    "downsamplers": "DownSampler",
+                    "upsamplers": "UpSampler",
+                }[names[2]]
+                block_type_with_id = ".".join(names[:4])
+                if block_type_with_id != last_block_type_with_id[block_type]:
+                    block_id[block_type] += 1
+                    last_block_type_with_id[block_type] = block_type_with_id
+                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+                    block_id[block_type] += 1
+                block_type_with_id = ".".join(names[:4])
+                names = ["blocks", str(block_id[block_type])] + names[4:]
+                if "ff" in names:
+                    ff_index = names.index("ff")
+                    component = ".".join(names[ff_index : ff_index + 3])
+                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+                    names = names[:ff_index] + [component] + names[ff_index + 3 :]
+                if "to_out" in names:
+                    names.pop(names.index("to_out") + 1)
+            else:
+                raise ValueError(f"Unknown parameters: {name}")
+            rename_dict[name] = ".".join(names)
+
+        # convert state_dict
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if ".proj_in." in name or ".proj_out." in name:
+                param = param.squeeze()
+            state_dict_[rename_dict[name]] = param
+        return state_dict_
+
+    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["civitai"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            name, suffix = split_suffix(name)
+            if name in rename_dict:
+                if ".proj_in" in name or ".proj_out" in name:
+                    param = param.squeeze()
+                new_key = rename_dict[name] + suffix
+                state_dict_[new_key] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        if "model.diffusion_model.input_blocks.0.0.weight" in state_dict:
+            state_dict = self._from_civitai(state_dict)
+            logger.info("use civitai format state dict")
+        elif "down_blocks.0.attentions.0.norm.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
+        else:
+            logger.info("use diffsynth format state dict")
+        return state_dict
+
+
+class SDUNet(PreTrainedModel):
+    converter = SDUNetStateDictConverter()
+
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
+        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
+
+        self.blocks = nn.ModuleList(
+            [
+                # CrossAttnDownBlock2D
+                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 40, 320, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 40, 320, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                DownSampler(320, device=device, dtype=dtype),
+                PushBlock(),
+                # CrossAttnDownBlock2D
+                ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 80, 640, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 80, 640, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                DownSampler(640, device=device, dtype=dtype),
+                PushBlock(),
+                # CrossAttnDownBlock2D
+                ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PushBlock(),
+                DownSampler(1280, device=device, dtype=dtype),
+                PushBlock(),
+                # DownBlock2D
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                PushBlock(),
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                PushBlock(),
+                # UNetMidBlock2DCrossAttn
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+                # UpBlock2D
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                UpSampler(1280, device=device, dtype=dtype),
+                # CrossAttnUpBlock2D
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(2560, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(1920, 1280, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                UpSampler(1280, device=device, dtype=dtype),
+                # CrossAttnUpBlock2D
+                PopBlock(),
+                ResnetBlock(1920, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 80, 640, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(1280, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 80, 640, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(960, 640, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 80, 640, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                UpSampler(640, device=device, dtype=dtype),
+                # CrossAttnUpBlock2D
+                PopBlock(),
+                ResnetBlock(960, 320, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 40, 320, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 40, 320, 1, 768, eps=1e-6, device=device, dtype=dtype),
+                PopBlock(),
+                ResnetBlock(640, 320, 1280, device=device, dtype=dtype),
+                AttentionBlock(8, 40, 320, 1, 768, eps=1e-6, device=device, dtype=dtype),
+            ]
+        )
+
+        self.conv_norm_out = nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5, device=device, dtype=dtype)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1, device=device, dtype=dtype)
+
+    def forward(self, x, timestep, context, **kwargs):
+        # 1. time
+        time_emb = self.time_embedding(timestep, dtype=x.dtype)
+
+        # 2. pre-process
+        hidden_states = self.conv_in(x)
+        text_emb = context
+        res_stack = [hidden_states]
+
+        # 3. blocks
+        for i, block in enumerate(self.blocks):
+            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+        # 4. output
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
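Note: a shape-level sketch of the forward pass above, with illustrative tensor sizes. x is a 4-channel latent at 1/8 image resolution, context is the 77-token, 768-dim CLIP sequence from SDTextEncoder, and the output has the same shape as x. The re-export from `diffsynth_engine.models.sd` is assumed.

import torch
from diffsynth_engine.models.sd import SDUNet  # assumed re-export

unet = SDUNet(device="cuda:0", dtype=torch.float16)  # fresh weights; shape check only

x = torch.randn(1, 4, 64, 64, device="cuda:0", dtype=torch.float16)      # 512x512 image -> 64x64 latent
timestep = torch.tensor([500], device="cuda:0")                           # current denoising step
context = torch.randn(1, 77, 768, device="cuda:0", dtype=torch.float16)  # CLIP text embeddings

noise_pred = unet(x, timestep, context)  # same shape as x: [1, 4, 64, 64]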
diffsynth_engine/models/sd/sd_vae.py
@@ -0,0 +1,38 @@
+import torch
+from typing import Dict
+
+from diffsynth_engine.models.components.vae import VAEDecoder, VAEEncoder
+from diffsynth_engine.models.utils import no_init_weights
+
+
+class SDVAEEncoder(VAEEncoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=4, scaling_factor=0.18215, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class SDVAEDecoder(VAEDecoder):
+    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+        super().__init__(
+            latent_channels=4,
+            scaling_factor=0.18215,
+            shift_factor=0,
+            use_post_quant_conv=True,
+            device=device,
+            dtype=dtype,
+        )
+
+    @classmethod
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+        with no_init_weights():
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+        model.load_state_dict(state_dict)
+        return model
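Note: both subclasses pin the Stable Diffusion latent convention (4 latent channels, scaling_factor 0.18215, shift_factor 0). For reference, this is how that convention maps between raw VAE latents and the normalized latents the UNet operates on; a sketch of the arithmetic only, not of these classes' exact call signatures:

import torch

SCALING_FACTOR, SHIFT_FACTOR = 0.18215, 0.0

def normalize_latents(z_raw: torch.Tensor) -> torch.Tensor:
    # raw encoder output -> roughly unit-variance latents for the diffusion model
    return (z_raw - SHIFT_FACTOR) * SCALING_FACTOR

def denormalize_latents(z: torch.Tensor) -> torch.Tensor:
    # diffusion-model latents -> what the decoder expects
    return z / SCALING_FACTOR + SHIFT_FACTOR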
diffsynth_engine/models/sd3/__init__.py
@@ -0,0 +1,14 @@
+from .sd3_dit import SD3DiT, config as sd3_dit_config
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, config as sd3_text_encoder_config
+from .sd3_vae import SD3VAEEncoder, SD3VAEDecoder
+
+__all__ = [
+    "SD3DiT",
+    "SD3TextEncoder1",
+    "SD3TextEncoder2",
+    "SD3TextEncoder3",
+    "SD3VAEEncoder",
+    "SD3VAEDecoder",
+    "sd3_dit_config",
+    "sd3_text_encoder_config",
+]
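Note: because of these re-exports, downstream code can pull the SD3 components and their configs from the subpackage root:

from diffsynth_engine.models.sd3 import SD3DiT, SD3VAEEncoder, SD3VAEDecoder, sd3_dit_config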