diffsynth-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113):
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,393 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Dict
6
+
7
+ from diffsynth_engine.models.basic.attention import Attention
8
+ from diffsynth_engine.models.basic.unet_helper import ResnetBlock, UpSampler, DownSampler
9
+ from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
10
+ from diffsynth_engine.models.utils import no_init_weights
11
+ from diffsynth_engine.utils.constants import VAE_CONFIG_FILE
12
+ from diffsynth_engine.utils import logging
13
+
14
logger = logging.get_logger(__name__)

# Load the shared VAE configuration (architecture/rename tables) once at
# import time; `VAE_CONFIG_FILE` is a project constant pointing at a JSON file.
with open(VAE_CONFIG_FILE, "r") as f:
    config = json.load(f)
18
+
19
+
20
class VAEStateDictConverter(StateDictConverter):
    """Convert VAE checkpoints (civitai or diffsynth layout) into the key
    format expected by ``VAEEncoder`` / ``VAEDecoder`` / ``VAE``."""

    def __init__(self, has_encoder: bool = False, has_decoder: bool = False):
        # Which sub-modules the target model contains; controls key filtering.
        self.has_encoder = has_encoder
        self.has_decoder = has_decoder

    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Rename civitai-format keys using the mapping from the VAE config.

        Keys absent from the rename table are dropped.  Parameters mapped
        into ``transformer_blocks`` are squeezed (presumably 1x1-conv-shaped
        attention weights flattened to matrices — TODO confirm against the
        rename table).
        """
        rename_dict = config["civitai"]["rename_dict"]
        new_state_dict = {}
        for key, param in state_dict.items():
            if key not in rename_dict:
                continue
            new_key = rename_dict[key]
            if "transformer_blocks" in new_key:
                param = param.squeeze()
            new_state_dict[new_key] = param
        return new_state_dict

    def _filter(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Keep only keys relevant to the configured sub-modules.

        When exactly one of encoder/decoder is enabled, the matching
        ``encoder.`` / ``decoder.`` prefix is stripped so the weights load
        directly into the standalone sub-module.
        """
        new_state_dict = {}
        for key, param in state_dict.items():
            if self.has_encoder and self.has_decoder:
                new_state_dict[key] = param
            elif self.has_encoder and key.startswith("encoder."):
                new_state_dict[key[len("encoder.") :]] = param
            elif self.has_decoder and key.startswith("decoder."):
                new_state_dict[key[len("decoder.") :]] = param
        return new_state_dict

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Detect the checkpoint format, rename if needed, then filter keys.

        Raises:
            ValueError: if the converter was constructed with neither
                encoder nor decoder enabled.
        """
        # An `assert` here would be stripped under `python -O`; raise instead.
        if not (self.has_decoder or self.has_encoder):
            raise ValueError("Either decoder or encoder must be present")
        if (
            "first_stage_model.decoder.conv_in.weight" in state_dict
            or "first_stage_model.encoder.conv_in.weight" in state_dict
        ):
            state_dict = self._from_civitai(state_dict)
            logger.info("use civitai format state dict")
        else:
            logger.info("use diffsynth format state dict")
        return self._filter(state_dict)
59
+
60
+
61
class VAEAttentionBlock(nn.Module):
    """Self-attention block used inside the VAE mid-blocks.

    GroupNorm -> flatten spatial dims -> a stack of attention layers ->
    reshape back -> residual add.  Shares the four-tuple block interface of
    the other VAE blocks but ignores the embedding arguments.
    """

    def __init__(
        self,
        num_attention_heads,
        attention_head_dim,
        in_channels,
        num_layers=1,
        norm_num_groups=32,
        eps=1e-5,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True, device=device, dtype=dtype
        )

        layers = []
        for _ in range(num_layers):
            layers.append(
                Attention(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    bias_q=True,
                    bias_kv=True,
                    bias_out=True,
                    attn_implementation="xformers",
                    device=device,
                    dtype=dtype,
                )
            )
        self.transformer_blocks = nn.ModuleList(layers)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        """Attend over spatial positions of a (B, C, H, W) tensor.

        `time_emb`, `text_emb` and `res_stack` are passed through untouched.
        """
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        normed = self.norm(hidden_states)
        channels = normed.shape[1]
        # (B, C, H, W) -> (B, H*W, C): tokens for the attention layers.
        tokens = normed.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        for attn in self.transformer_blocks:
            tokens = attn(tokens)

        # (B, H*W, C) -> (B, C, H, W), then the residual connection.
        out = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        out = out + residual

        return out, time_emb, text_emb, res_stack
112
+
113
+
114
class VAEDecoder(PreTrainedModel):
    """SD-style VAE decoder: maps latents back to 3-channel images.

    Architecture: optional 1x1 ``post_quant_conv`` -> ``conv_in`` -> UNet mid
    block (resnet/attention/resnet) -> four up-decoder stages -> GroupNorm /
    SiLU / ``conv_out``.
    """

    converter = VAEStateDictConverter(has_decoder=True)

    def __init__(
        self,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_post_quant_conv: bool = True,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.latent_channels = latent_channels
        # Inverse of the encoder normalization: latents are divided by
        # `scaling_factor` and offset by `shift_factor` before decoding.
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor
        self.use_post_quant_conv = use_post_quant_conv
        if use_post_quant_conv:
            self.post_quant_conv = nn.Conv2d(
                latent_channels, latent_channels, kernel_size=1, device=device, dtype=dtype
            )
        self.conv_in = nn.Conv2d(latent_channels, 512, kernel_size=3, padding=1, device=device, dtype=dtype)

        self.blocks = nn.ModuleList(
            [
                # UNetMidBlock2D
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                # UpDecoderBlock2D
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                UpSampler(512, device=device, dtype=dtype),
                # UpDecoderBlock2D
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                UpSampler(512, device=device, dtype=dtype),
                # UpDecoderBlock2D
                ResnetBlock(512, 256, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(256, 256, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(256, 256, eps=1e-6, device=device, dtype=dtype),
                UpSampler(256, device=device, dtype=dtype),
                # UpDecoderBlock2D
                ResnetBlock(256, 128, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(128, 128, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(128, 128, eps=1e-6, device=device, dtype=dtype),
            ]
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6, device=device, dtype=dtype)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1, device=device, dtype=dtype)

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode latents ``sample`` to images; output keeps the input dtype.

        Raises:
            NotImplementedError: if ``tiled`` is True (tiling not implemented;
                ``tile_size``/``tile_stride`` are currently unused).
        """
        original_dtype = sample.dtype
        # Run the network in the parameters' dtype, cast back at the end.
        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
        if tiled:
            raise NotImplementedError()

        # 1. pre-process: undo latent normalization, optional 1x1 conv.
        sample = sample / self.scaling_factor + self.shift_factor
        if self.use_post_quant_conv:
            sample = self.post_quant_conv(sample)
        hidden_states = self.conv_in(sample)
        # Blocks share the UNet block interface but no embeddings are used.
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks (index was unused — iterate directly)
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = hidden_states.to(original_dtype)

        return hidden_states

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_post_quant_conv: bool = True,
    ):
        """Build a decoder without weight initialization and load ``state_dict``."""
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls,
                latent_channels=latent_channels,
                scaling_factor=scaling_factor,
                shift_factor=shift_factor,
                use_post_quant_conv=use_post_quant_conv,
                device=device,
                dtype=dtype,
            )
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_path: str | os.PathLike, **kwargs):
        # Loading from a pretrained path is intentionally unsupported here;
        # use `from_state_dict` instead.
        raise NotImplementedError()
223
+
224
+
225
class VAEEncoder(PreTrainedModel):
    """SD-style VAE encoder: maps 3-channel images to latents.

    Architecture: ``conv_in`` -> four down-encoder stages -> UNet mid block
    (resnet/attention/resnet) -> GroupNorm / SiLU / ``conv_out`` producing
    ``2 * latent_channels`` channels, optionally passed through ``quant_conv``.
    """

    converter = VAEStateDictConverter(has_encoder=True)

    def __init__(
        self,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_quant_conv: bool = True,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.latent_channels = latent_channels
        # Latents are shifted by `-shift_factor` and multiplied by
        # `scaling_factor` on the way out (the decoder inverts this).
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor
        self.use_quant_conv = use_quant_conv
        if use_quant_conv:
            self.quant_conv = nn.Conv2d(
                2 * latent_channels, 2 * latent_channels, kernel_size=1, device=device, dtype=dtype
            )
        self.conv_in = nn.Conv2d(3, 128, kernel_size=3, padding=1, device=device, dtype=dtype)

        self.blocks = nn.ModuleList(
            [
                # DownEncoderBlock2D
                ResnetBlock(128, 128, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(128, 128, eps=1e-6, device=device, dtype=dtype),
                DownSampler(128, padding=0, extra_padding=True, device=device, dtype=dtype),
                # DownEncoderBlock2D
                ResnetBlock(128, 256, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(256, 256, eps=1e-6, device=device, dtype=dtype),
                DownSampler(256, padding=0, extra_padding=True, device=device, dtype=dtype),
                # DownEncoderBlock2D
                ResnetBlock(256, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                DownSampler(512, padding=0, extra_padding=True, device=device, dtype=dtype),
                # DownEncoderBlock2D
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                # UNetMidBlock2D
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
                ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
            ]
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6, device=device, dtype=dtype)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(512, 2 * latent_channels, kernel_size=3, padding=1, device=device, dtype=dtype)

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode images ``sample`` to latents; output keeps the input dtype.

        Raises:
            NotImplementedError: if ``tiled`` is True (tiling not implemented;
                ``tile_size``/``tile_stride`` are currently unused).
        """
        original_dtype = sample.dtype
        # Run the network in the parameters' dtype, cast back at the end.
        sample = sample.to(dtype=next(iter(self.parameters())).dtype)
        if tiled:
            raise NotImplementedError()

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        # Blocks share the UNet block interface but no embeddings are used.
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks (index was unused — iterate directly)
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        if self.use_quant_conv:
            hidden_states = self.quant_conv(hidden_states)
        # Keep only the first `latent_channels` of the 2x-channel output —
        # presumably the distribution mean, discarding the logvar half
        # (deterministic encode); confirm against the sampling pipelines.
        hidden_states = hidden_states[:, : self.latent_channels]
        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
        hidden_states = hidden_states.to(original_dtype)
        return hidden_states

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_quant_conv: bool = True,
    ):
        """Build an encoder without weight initialization and load ``state_dict``."""
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls,
                latent_channels=latent_channels,
                scaling_factor=scaling_factor,
                shift_factor=shift_factor,
                use_quant_conv=use_quant_conv,
                device=device,
                dtype=dtype,
            )
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_path: str | os.PathLike, **kwargs):
        # Loading from a pretrained path is intentionally unsupported here;
        # use `from_state_dict` instead.
        raise NotImplementedError()
330
+
331
+
332
class VAE(PreTrainedModel):
    """Full VAE combining a ``VAEEncoder`` and a ``VAEDecoder``.

    ``encode``/``decode`` simply delegate to the respective sub-module; the
    state-dict converter keeps both ``encoder.`` and ``decoder.`` prefixed keys.
    """

    converter = VAEStateDictConverter(has_encoder=True, has_decoder=True)

    def __init__(
        self,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        # Parameters shared by both sub-modules.
        shared = dict(
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            device=device,
            dtype=dtype,
        )
        self.encoder = VAEEncoder(use_quant_conv=use_quant_conv, **shared)
        self.decoder = VAEDecoder(use_post_quant_conv=use_post_quant_conv, **shared)

    def encode(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode images to latents (delegates to the encoder)."""
        return self.encoder(sample, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, **kwargs)

    def decode(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode latents to images (delegates to the decoder)."""
        return self.decoder(sample, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, **kwargs)

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        latent_channels: int = 4,
        scaling_factor: float = 0.18215,
        shift_factor: float = 0,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
    ):
        """Build a VAE without weight initialization and load ``state_dict``."""
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls,
                latent_channels=latent_channels,
                scaling_factor=scaling_factor,
                shift_factor=shift_factor,
                use_quant_conv=use_quant_conv,
                use_post_quant_conv=use_post_quant_conv,
                device=device,
                dtype=dtype,
            )
        model.load_state_dict(state_dict)
        return model
@@ -0,0 +1,14 @@
1
from .flux_dit import FluxDiT, config as flux_dit_config
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2, config as flux_text_encoder_config
from .flux_vae import FluxVAEDecoder, FluxVAEEncoder, config as flux_vae_config

# Public API of the flux models subpackage: model classes plus their
# per-module `config` objects (re-exported under distinct names).
__all__ = [
    "FluxDiT",
    "FluxTextEncoder1",
    "FluxTextEncoder2",
    "FluxVAEDecoder",
    "FluxVAEEncoder",
    "flux_dit_config",
    "flux_text_encoder_config",
    "flux_vae_config",
]