diffsynth-engine 0.3.6.dev8__py3-none-any.whl → 0.3.6.dev10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +10 -8
- diffsynth_engine/configs/__init__.py +23 -0
- diffsynth_engine/configs/controlnet.py +17 -0
- diffsynth_engine/configs/pipeline.py +206 -0
- diffsynth_engine/models/basic/attention.py +43 -4
- diffsynth_engine/models/flux/flux_controlnet.py +8 -5
- diffsynth_engine/models/flux/flux_dit.py +22 -16
- diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/sd/sd_controlnet.py +2 -4
- diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
- diffsynth_engine/models/wan/wan_dit.py +15 -15
- diffsynth_engine/pipelines/__init__.py +5 -8
- diffsynth_engine/pipelines/base.py +14 -65
- diffsynth_engine/pipelines/flux_image.py +85 -158
- diffsynth_engine/pipelines/sd_image.py +30 -64
- diffsynth_engine/pipelines/sdxl_image.py +39 -71
- diffsynth_engine/pipelines/wan_video.py +66 -105
- diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_reference_tool.py +21 -5
- diffsynth_engine/tools/flux_replace_tool.py +15 -3
- diffsynth_engine/utils/fp8_linear.py +14 -5
- diffsynth_engine/utils/parallel.py +1 -1
- diffsynth_engine/utils/platform.py +9 -1
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/RECORD +30 -27
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,12 @@
|
|
|
1
1
|
import re
|
|
2
|
-
import os
|
|
3
2
|
import torch
|
|
4
3
|
import numpy as np
|
|
5
4
|
from einops import repeat
|
|
6
|
-
from dataclasses import dataclass
|
|
7
5
|
from typing import Callable, Dict, Optional, List
|
|
8
6
|
from tqdm import tqdm
|
|
9
7
|
from PIL import Image, ImageOps
|
|
10
8
|
|
|
9
|
+
from diffsynth_engine.configs import SDPipelineConfig
|
|
11
10
|
from diffsynth_engine.models.base import split_suffix
|
|
12
11
|
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
13
12
|
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
|
|
@@ -84,17 +83,6 @@ def convert_diffusers_name_to_compvis(key):
|
|
|
84
83
|
return key
|
|
85
84
|
|
|
86
85
|
|
|
87
|
-
@dataclass
|
|
88
|
-
class SDModelConfig:
|
|
89
|
-
unet_path: str | os.PathLike
|
|
90
|
-
clip_path: Optional[str | os.PathLike] = None
|
|
91
|
-
vae_path: Optional[str | os.PathLike] = None
|
|
92
|
-
|
|
93
|
-
unet_dtype: torch.dtype = torch.float16
|
|
94
|
-
clip_dtype: torch.dtype = torch.float16
|
|
95
|
-
vae_dtype: torch.dtype = torch.float32
|
|
96
|
-
|
|
97
|
-
|
|
98
86
|
class SDLoRAConverter(LoRAStateDictConverter):
|
|
99
87
|
def _replace_kohya_te_key(self, key):
|
|
100
88
|
key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
|
|
@@ -151,27 +139,22 @@ class SDImagePipeline(BasePipeline):
|
|
|
151
139
|
|
|
152
140
|
def __init__(
|
|
153
141
|
self,
|
|
154
|
-
config:
|
|
142
|
+
config: SDPipelineConfig,
|
|
155
143
|
tokenizer: CLIPTokenizer,
|
|
156
144
|
text_encoder: SDTextEncoder,
|
|
157
145
|
unet: SDUNet,
|
|
158
146
|
vae_decoder: SDVAEDecoder,
|
|
159
147
|
vae_encoder: SDVAEEncoder,
|
|
160
|
-
batch_cfg: bool = True,
|
|
161
|
-
vae_tiled: bool = False,
|
|
162
|
-
vae_tile_size: int = 256,
|
|
163
|
-
vae_tile_stride: int = 256,
|
|
164
|
-
device: str = "cuda",
|
|
165
|
-
dtype: torch.dtype = torch.float16,
|
|
166
148
|
):
|
|
167
149
|
super().__init__(
|
|
168
|
-
vae_tiled=vae_tiled,
|
|
169
|
-
vae_tile_size=vae_tile_size,
|
|
170
|
-
vae_tile_stride=vae_tile_stride,
|
|
171
|
-
device=device,
|
|
172
|
-
dtype=
|
|
150
|
+
vae_tiled=config.vae_tiled,
|
|
151
|
+
vae_tile_size=config.vae_tile_size,
|
|
152
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
153
|
+
device=config.device,
|
|
154
|
+
dtype=config.model_dtype,
|
|
173
155
|
)
|
|
174
156
|
self.config = config
|
|
157
|
+
# sampler
|
|
175
158
|
self.noise_scheduler = ScaledLinearScheduler()
|
|
176
159
|
self.sampler = EulerSampler()
|
|
177
160
|
# models
|
|
@@ -180,71 +163,54 @@ class SDImagePipeline(BasePipeline):
|
|
|
180
163
|
self.unet = unet
|
|
181
164
|
self.vae_decoder = vae_decoder
|
|
182
165
|
self.vae_encoder = vae_encoder
|
|
183
|
-
self.batch_cfg = batch_cfg
|
|
184
166
|
self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
|
185
167
|
|
|
186
168
|
@classmethod
|
|
187
|
-
def from_pretrained(
|
|
188
|
-
cls,
|
|
189
|
-
model_path_or_config: str | os.PathLike | SDModelConfig,
|
|
190
|
-
batch_cfg: bool = True,
|
|
191
|
-
vae_tiled: bool = False,
|
|
192
|
-
vae_tile_size: int = 256,
|
|
193
|
-
vae_tile_stride: int = 256,
|
|
194
|
-
device: str = "cuda",
|
|
195
|
-
dtype: torch.dtype = torch.float16,
|
|
196
|
-
offload_mode: str | None = None,
|
|
197
|
-
) -> "SDImagePipeline":
|
|
169
|
+
def from_pretrained(cls, model_path_or_config: SDPipelineConfig) -> "SDImagePipeline":
|
|
198
170
|
if isinstance(model_path_or_config, str):
|
|
199
|
-
|
|
171
|
+
config = SDPipelineConfig(model_path=model_path_or_config)
|
|
200
172
|
else:
|
|
201
|
-
|
|
173
|
+
config = model_path_or_config
|
|
202
174
|
|
|
203
|
-
logger.info(f"loading state dict from {
|
|
204
|
-
unet_state_dict = cls.load_model_checkpoint(
|
|
175
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
176
|
+
unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
205
177
|
|
|
206
|
-
if
|
|
207
|
-
logger.info(f"loading state dict from {
|
|
208
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
178
|
+
if config.vae_path is not None:
|
|
179
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
180
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
209
181
|
else:
|
|
210
182
|
vae_state_dict = unet_state_dict
|
|
211
183
|
|
|
212
|
-
if
|
|
213
|
-
logger.info(f"loading state dict from {
|
|
214
|
-
clip_state_dict = cls.load_model_checkpoint(
|
|
184
|
+
if config.clip_path is not None:
|
|
185
|
+
logger.info(f"loading state dict from {config.clip_path} ...")
|
|
186
|
+
clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
|
|
215
187
|
else:
|
|
216
188
|
clip_state_dict = unet_state_dict
|
|
217
189
|
|
|
218
|
-
init_device = "cpu" if offload_mode else device
|
|
190
|
+
init_device = "cpu" if config.offload_mode is not None else config.device
|
|
219
191
|
tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
|
|
220
192
|
with LoRAContext():
|
|
221
|
-
text_encoder = SDTextEncoder.from_state_dict(
|
|
222
|
-
|
|
223
|
-
)
|
|
224
|
-
unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
|
|
193
|
+
text_encoder = SDTextEncoder.from_state_dict(clip_state_dict, device=init_device, dtype=config.clip_dtype)
|
|
194
|
+
unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
|
|
225
195
|
vae_decoder = SDVAEDecoder.from_state_dict(
|
|
226
|
-
vae_state_dict, device=init_device, dtype=
|
|
196
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
227
197
|
)
|
|
228
198
|
vae_encoder = SDVAEEncoder.from_state_dict(
|
|
229
|
-
vae_state_dict, device=init_device, dtype=
|
|
199
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
230
200
|
)
|
|
231
201
|
|
|
232
202
|
pipe = cls(
|
|
233
|
-
config=
|
|
203
|
+
config=config,
|
|
234
204
|
tokenizer=tokenizer,
|
|
235
205
|
text_encoder=text_encoder,
|
|
236
206
|
unet=unet,
|
|
237
207
|
vae_decoder=vae_decoder,
|
|
238
208
|
vae_encoder=vae_encoder,
|
|
239
|
-
batch_cfg=batch_cfg,
|
|
240
|
-
vae_tiled=vae_tiled,
|
|
241
|
-
vae_tile_size=vae_tile_size,
|
|
242
|
-
vae_tile_stride=vae_tile_stride,
|
|
243
|
-
device=device,
|
|
244
|
-
dtype=dtype,
|
|
245
209
|
)
|
|
246
|
-
|
|
247
|
-
|
|
210
|
+
pipe.eval()
|
|
211
|
+
|
|
212
|
+
if config.offload_mode is not None:
|
|
213
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
248
214
|
return pipe
|
|
249
215
|
|
|
250
216
|
@classmethod
|
|
@@ -439,7 +405,7 @@ class SDImagePipeline(BasePipeline):
|
|
|
439
405
|
controlnet_params=controlnet_params,
|
|
440
406
|
current_step=i,
|
|
441
407
|
total_step=len(timesteps),
|
|
442
|
-
batch_cfg=self.batch_cfg,
|
|
408
|
+
batch_cfg=self.config.batch_cfg,
|
|
443
409
|
)
|
|
444
410
|
# Denoise
|
|
445
411
|
latents = self.sampler.step(latents, noise_pred, i)
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import re
|
|
3
2
|
import torch
|
|
4
3
|
import numpy as np
|
|
@@ -6,8 +5,8 @@ from einops import repeat
|
|
|
6
5
|
from typing import Callable, Dict, Optional, List
|
|
7
6
|
from tqdm import tqdm
|
|
8
7
|
from PIL import Image, ImageOps
|
|
9
|
-
from dataclasses import dataclass
|
|
10
8
|
|
|
9
|
+
from diffsynth_engine.configs import SDXLPipelineConfig
|
|
11
10
|
from diffsynth_engine.models.base import split_suffix
|
|
12
11
|
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
13
12
|
from diffsynth_engine.models.basic.timestep import TemporalTimesteps
|
|
@@ -102,25 +101,12 @@ class SDXLLoRAConverter(LoRAStateDictConverter):
|
|
|
102
101
|
raise ValueError(f"Unsupported key: {key}")
|
|
103
102
|
|
|
104
103
|
|
|
105
|
-
@dataclass
|
|
106
|
-
class SDXLModelConfig:
|
|
107
|
-
unet_path: str | os.PathLike
|
|
108
|
-
clip_l_path: Optional[str | os.PathLike] = None
|
|
109
|
-
clip_g_path: Optional[str | os.PathLike] = None
|
|
110
|
-
vae_path: Optional[str | os.PathLike] = None
|
|
111
|
-
|
|
112
|
-
unet_dtype: torch.dtype = torch.float16
|
|
113
|
-
clip_l_dtype: torch.dtype = torch.float16
|
|
114
|
-
clip_g_dtype: torch.dtype = torch.float16
|
|
115
|
-
vae_dtype: torch.dtype = torch.float32
|
|
116
|
-
|
|
117
|
-
|
|
118
104
|
class SDXLImagePipeline(BasePipeline):
|
|
119
105
|
lora_converter = SDXLLoRAConverter()
|
|
120
106
|
|
|
121
107
|
def __init__(
|
|
122
108
|
self,
|
|
123
|
-
config:
|
|
109
|
+
config: SDXLPipelineConfig,
|
|
124
110
|
tokenizer: CLIPTokenizer,
|
|
125
111
|
tokenizer_2: CLIPTokenizer,
|
|
126
112
|
text_encoder: SDXLTextEncoder,
|
|
@@ -128,21 +114,16 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
128
114
|
unet: SDXLUNet,
|
|
129
115
|
vae_decoder: SDXLVAEDecoder,
|
|
130
116
|
vae_encoder: SDXLVAEEncoder,
|
|
131
|
-
batch_cfg: bool = True,
|
|
132
|
-
vae_tiled: bool = False,
|
|
133
|
-
vae_tile_size: int = 256,
|
|
134
|
-
vae_tile_stride: int = 256,
|
|
135
|
-
device: str = "cuda",
|
|
136
|
-
dtype: torch.dtype = torch.float16,
|
|
137
117
|
):
|
|
138
118
|
super().__init__(
|
|
139
|
-
vae_tiled=vae_tiled,
|
|
140
|
-
vae_tile_size=vae_tile_size,
|
|
141
|
-
vae_tile_stride=vae_tile_stride,
|
|
142
|
-
device=device,
|
|
143
|
-
dtype=
|
|
119
|
+
vae_tiled=config.vae_tiled,
|
|
120
|
+
vae_tile_size=config.vae_tile_size,
|
|
121
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
122
|
+
device=config.device,
|
|
123
|
+
dtype=config.model_dtype,
|
|
144
124
|
)
|
|
145
125
|
self.config = config
|
|
126
|
+
# sampler
|
|
146
127
|
self.noise_scheduler = ScaledLinearScheduler()
|
|
147
128
|
self.sampler = EulerSampler()
|
|
148
129
|
# models
|
|
@@ -154,71 +135,62 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
154
135
|
self.vae_decoder = vae_decoder
|
|
155
136
|
self.vae_encoder = vae_encoder
|
|
156
137
|
self.add_time_proj = TemporalTimesteps(
|
|
157
|
-
num_channels=256,
|
|
138
|
+
num_channels=256,
|
|
139
|
+
flip_sin_to_cos=True,
|
|
140
|
+
downscale_freq_shift=0,
|
|
141
|
+
device=config.device,
|
|
142
|
+
dtype=config.model_dtype,
|
|
158
143
|
)
|
|
159
|
-
self.batch_cfg = batch_cfg
|
|
160
144
|
self.model_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
|
161
145
|
|
|
162
146
|
@classmethod
|
|
163
|
-
def from_pretrained(
|
|
164
|
-
cls,
|
|
165
|
-
model_path_or_config: str | os.PathLike | SDXLModelConfig,
|
|
166
|
-
batch_cfg: bool = True,
|
|
167
|
-
vae_tiled: bool = False,
|
|
168
|
-
vae_tile_size: int = 256,
|
|
169
|
-
vae_tile_stride: int = 256,
|
|
170
|
-
device: str = "cuda",
|
|
171
|
-
dtype: torch.dtype = torch.float16,
|
|
172
|
-
offload_mode: str | None = None,
|
|
173
|
-
) -> "SDXLImagePipeline":
|
|
147
|
+
def from_pretrained(cls, model_path_or_config: SDXLPipelineConfig) -> "SDXLImagePipeline":
|
|
174
148
|
if isinstance(model_path_or_config, str):
|
|
175
|
-
|
|
176
|
-
unet_path=model_path_or_config, unet_dtype=dtype, clip_l_dtype=dtype, clip_g_dtype=dtype
|
|
177
|
-
)
|
|
149
|
+
config = SDXLPipelineConfig(model_path=model_path_or_config)
|
|
178
150
|
else:
|
|
179
|
-
|
|
151
|
+
config = model_path_or_config
|
|
180
152
|
|
|
181
|
-
logger.info(f"loading state dict from {
|
|
182
|
-
unet_state_dict = cls.load_model_checkpoint(
|
|
153
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
154
|
+
unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
183
155
|
|
|
184
|
-
if
|
|
185
|
-
logger.info(f"loading state dict from {
|
|
186
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
156
|
+
if config.vae_path is not None:
|
|
157
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
158
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
187
159
|
else:
|
|
188
160
|
vae_state_dict = unet_state_dict
|
|
189
161
|
|
|
190
|
-
if
|
|
191
|
-
logger.info(f"loading state dict from {
|
|
192
|
-
clip_l_state_dict = cls.load_model_checkpoint(
|
|
162
|
+
if config.clip_l_path is not None:
|
|
163
|
+
logger.info(f"loading state dict from {config.clip_l_path} ...")
|
|
164
|
+
clip_l_state_dict = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype)
|
|
193
165
|
else:
|
|
194
166
|
clip_l_state_dict = unet_state_dict
|
|
195
167
|
|
|
196
|
-
if
|
|
197
|
-
logger.info(f"loading state dict from {
|
|
198
|
-
clip_g_state_dict = cls.load_model_checkpoint(
|
|
168
|
+
if config.clip_g_path is not None:
|
|
169
|
+
logger.info(f"loading state dict from {config.clip_g_path} ...")
|
|
170
|
+
clip_g_state_dict = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype)
|
|
199
171
|
else:
|
|
200
172
|
clip_g_state_dict = unet_state_dict
|
|
201
173
|
|
|
202
|
-
init_device = "cpu" if offload_mode else device
|
|
174
|
+
init_device = "cpu" if config.offload_mode else config.device
|
|
203
175
|
tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
|
|
204
176
|
tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
|
|
205
177
|
with LoRAContext():
|
|
206
178
|
text_encoder = SDXLTextEncoder.from_state_dict(
|
|
207
|
-
clip_l_state_dict, device=init_device, dtype=
|
|
179
|
+
clip_l_state_dict, device=init_device, dtype=config.clip_l_dtype
|
|
208
180
|
)
|
|
209
181
|
text_encoder_2 = SDXLTextEncoder2.from_state_dict(
|
|
210
|
-
clip_g_state_dict, device=init_device, dtype=
|
|
182
|
+
clip_g_state_dict, device=init_device, dtype=config.clip_g_dtype
|
|
211
183
|
)
|
|
212
|
-
unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=
|
|
184
|
+
unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
|
|
213
185
|
vae_decoder = SDXLVAEDecoder.from_state_dict(
|
|
214
|
-
vae_state_dict, device=init_device, dtype=
|
|
186
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
215
187
|
)
|
|
216
188
|
vae_encoder = SDXLVAEEncoder.from_state_dict(
|
|
217
|
-
vae_state_dict, device=init_device, dtype=
|
|
189
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
218
190
|
)
|
|
219
191
|
|
|
220
192
|
pipe = cls(
|
|
221
|
-
config=
|
|
193
|
+
config=config,
|
|
222
194
|
tokenizer=tokenizer,
|
|
223
195
|
tokenizer_2=tokenizer_2,
|
|
224
196
|
text_encoder=text_encoder,
|
|
@@ -226,15 +198,11 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
226
198
|
unet=unet,
|
|
227
199
|
vae_decoder=vae_decoder,
|
|
228
200
|
vae_encoder=vae_encoder,
|
|
229
|
-
batch_cfg=batch_cfg,
|
|
230
|
-
vae_tiled=vae_tiled,
|
|
231
|
-
vae_tile_size=vae_tile_size,
|
|
232
|
-
vae_tile_stride=vae_tile_stride,
|
|
233
|
-
device=device,
|
|
234
|
-
dtype=dtype,
|
|
235
201
|
)
|
|
236
|
-
|
|
237
|
-
|
|
202
|
+
pipe.eval()
|
|
203
|
+
|
|
204
|
+
if config.offload_mode is not None:
|
|
205
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
238
206
|
return pipe
|
|
239
207
|
|
|
240
208
|
@classmethod
|
|
@@ -517,7 +485,7 @@ class SDXLImagePipeline(BasePipeline):
|
|
|
517
485
|
controlnet_params=controlnet_params,
|
|
518
486
|
current_step=i,
|
|
519
487
|
total_step=len(timesteps),
|
|
520
|
-
batch_cfg=self.batch_cfg,
|
|
488
|
+
batch_cfg=self.config.batch_cfg,
|
|
521
489
|
)
|
|
522
490
|
# Denoise
|
|
523
491
|
latents = self.sampler.step(latents, noise_pred, i)
|
|
@@ -2,11 +2,11 @@ import torch
|
|
|
2
2
|
import torch.distributed as dist
|
|
3
3
|
import numpy as np
|
|
4
4
|
from einops import rearrange
|
|
5
|
-
from dataclasses import dataclass
|
|
6
5
|
from typing import Callable, List, Tuple, Optional
|
|
7
6
|
from tqdm import tqdm
|
|
8
7
|
from PIL import Image
|
|
9
8
|
|
|
9
|
+
from diffsynth_engine.configs import WanPipelineConfig
|
|
10
10
|
from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
|
|
11
11
|
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
|
|
12
12
|
from diffsynth_engine.models.wan.wan_dit import WanDiT
|
|
@@ -18,6 +18,7 @@ from diffsynth_engine.tokenizers import WanT5Tokenizer
|
|
|
18
18
|
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
|
|
19
19
|
from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
|
|
20
20
|
from diffsynth_engine.utils.download import fetch_model
|
|
21
|
+
from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
|
|
21
22
|
from diffsynth_engine.utils.parallel import ParallelWrapper
|
|
22
23
|
from diffsynth_engine.utils import logging
|
|
23
24
|
|
|
@@ -25,26 +26,6 @@ from diffsynth_engine.utils import logging
|
|
|
25
26
|
logger = logging.get_logger(__name__)
|
|
26
27
|
|
|
27
28
|
|
|
28
|
-
@dataclass
|
|
29
|
-
class WanModelConfig:
|
|
30
|
-
model_path: Optional[str] = None
|
|
31
|
-
vae_path: Optional[str] = None
|
|
32
|
-
t5_path: Optional[str] = None
|
|
33
|
-
image_encoder_path: Optional[str] = None
|
|
34
|
-
|
|
35
|
-
vae_dtype: torch.dtype = torch.float32
|
|
36
|
-
dit_dtype: torch.dtype = torch.bfloat16
|
|
37
|
-
t5_dtype: torch.dtype = torch.bfloat16
|
|
38
|
-
image_encoder_dtype: torch.dtype = torch.bfloat16
|
|
39
|
-
|
|
40
|
-
dit_attn_impl: Optional[str] = "auto"
|
|
41
|
-
|
|
42
|
-
sp_ulysses_degree: Optional[int] = None
|
|
43
|
-
sp_ring_degree: Optional[int] = None
|
|
44
|
-
tp_degree: Optional[int] = None
|
|
45
|
-
use_fsdp: bool = False
|
|
46
|
-
|
|
47
|
-
|
|
48
29
|
class WanLoRAConverter(LoRAStateDictConverter):
|
|
49
30
|
def _from_diffsynth(self, state_dict):
|
|
50
31
|
dit_dict = {}
|
|
@@ -129,42 +110,40 @@ class WanVideoPipeline(BasePipeline):
|
|
|
129
110
|
|
|
130
111
|
def __init__(
|
|
131
112
|
self,
|
|
132
|
-
config:
|
|
113
|
+
config: WanPipelineConfig,
|
|
133
114
|
tokenizer: WanT5Tokenizer,
|
|
134
115
|
text_encoder: WanTextEncoder,
|
|
135
116
|
dit: WanDiT,
|
|
136
117
|
vae: WanVideoVAE,
|
|
137
118
|
image_encoder: WanImageEncoder,
|
|
138
|
-
shift: float = 5.0,
|
|
139
|
-
batch_cfg: bool = False,
|
|
140
|
-
vae_tiled: bool = True,
|
|
141
|
-
vae_tile_size: Tuple[int, int] = (34, 34),
|
|
142
|
-
vae_tile_stride: Tuple[int, int] = (18, 16),
|
|
143
|
-
device="cuda",
|
|
144
|
-
dtype=torch.bfloat16,
|
|
145
119
|
):
|
|
146
120
|
super().__init__(
|
|
147
|
-
vae_tiled=vae_tiled,
|
|
148
|
-
vae_tile_size=vae_tile_size,
|
|
149
|
-
vae_tile_stride=vae_tile_stride,
|
|
150
|
-
device=device,
|
|
151
|
-
dtype=
|
|
121
|
+
vae_tiled=config.vae_tiled,
|
|
122
|
+
vae_tile_size=config.vae_tile_size,
|
|
123
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
124
|
+
device=config.device,
|
|
125
|
+
dtype=config.model_dtype,
|
|
152
126
|
)
|
|
153
127
|
self.config = config
|
|
154
|
-
|
|
128
|
+
# sampler
|
|
129
|
+
self.noise_scheduler = RecifitedFlowScheduler(
|
|
130
|
+
shift=config.shift if config.shift is not None else 5.0,
|
|
131
|
+
sigma_min=0.001,
|
|
132
|
+
sigma_max=0.999,
|
|
133
|
+
)
|
|
155
134
|
self.sampler = FlowMatchEulerSampler()
|
|
135
|
+
# models
|
|
156
136
|
self.tokenizer = tokenizer
|
|
157
137
|
self.text_encoder = text_encoder
|
|
158
138
|
self.dit = dit
|
|
159
139
|
self.vae = vae
|
|
160
140
|
self.image_encoder = image_encoder
|
|
161
|
-
self.batch_cfg = batch_cfg
|
|
162
141
|
self.model_names = ["text_encoder", "dit", "vae", "image_encoder"]
|
|
163
142
|
|
|
164
143
|
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
|
|
165
|
-
assert self.config.tp_degree is None, (
|
|
144
|
+
assert self.config.tp_degree is None or self.config.tp_degree == 1, (
|
|
166
145
|
"load LoRA is not allowed when tensor parallel is enabled; "
|
|
167
|
-
"set tp_degree=None during pipeline initialization"
|
|
146
|
+
"set tp_degree=None or tp_degree=1 during pipeline initialization"
|
|
168
147
|
)
|
|
169
148
|
assert not (self.config.use_fsdp and fused), (
|
|
170
149
|
"load fused LoRA is not allowed when fully sharded data parallel is enabled; "
|
|
@@ -228,7 +207,7 @@ class WanVideoPipeline(BasePipeline):
|
|
|
228
207
|
tile_size=self.vae_tile_size,
|
|
229
208
|
tile_stride=self.vae_tile_stride,
|
|
230
209
|
)
|
|
231
|
-
latents = latents.to(dtype=self.config.
|
|
210
|
+
latents = latents.to(dtype=self.config.model_dtype, device=self.device)
|
|
232
211
|
return latents
|
|
233
212
|
|
|
234
213
|
def decode_video(self, latents, progress_callback=None) -> List[torch.Tensor]:
|
|
@@ -241,7 +220,7 @@ class WanVideoPipeline(BasePipeline):
|
|
|
241
220
|
tile_stride=self.vae_tile_stride,
|
|
242
221
|
progress_callback=progress_callback,
|
|
243
222
|
)
|
|
244
|
-
videos = [video.to(dtype=self.config.
|
|
223
|
+
videos = [video.to(dtype=self.config.model_dtype, device=self.device) for video in videos]
|
|
245
224
|
return videos
|
|
246
225
|
|
|
247
226
|
def predict_noise_with_cfg(
|
|
@@ -301,7 +280,7 @@ class WanVideoPipeline(BasePipeline):
|
|
|
301
280
|
return noise_pred
|
|
302
281
|
|
|
303
282
|
def predict_noise(self, latents, image_clip_feature, image_y, timestep, context):
|
|
304
|
-
latents = latents.to(dtype=self.config.
|
|
283
|
+
latents = latents.to(dtype=self.config.model_dtype, device=self.device)
|
|
305
284
|
|
|
306
285
|
noise_pred = self.dit(
|
|
307
286
|
x=latents,
|
|
@@ -386,7 +365,7 @@ class WanVideoPipeline(BasePipeline):
|
|
|
386
365
|
self.load_models_to_device(["dit"])
|
|
387
366
|
hide_progress = dist.is_initialized() and dist.get_rank() != 0
|
|
388
367
|
for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
|
|
389
|
-
timestep = timestep.unsqueeze(0).to(dtype=self.config.
|
|
368
|
+
timestep = timestep.unsqueeze(0).to(dtype=self.config.model_dtype, device=self.device)
|
|
390
369
|
# Classifier-free guidance
|
|
391
370
|
noise_pred = self.predict_noise_with_cfg(
|
|
392
371
|
latents=latents,
|
|
@@ -396,7 +375,7 @@ class WanVideoPipeline(BasePipeline):
|
|
|
396
375
|
image_clip_feature=image_clip_feature,
|
|
397
376
|
image_y=image_y,
|
|
398
377
|
cfg_scale=cfg_scale,
|
|
399
|
-
batch_cfg=self.batch_cfg,
|
|
378
|
+
batch_cfg=self.config.batch_cfg,
|
|
400
379
|
)
|
|
401
380
|
# Scheduler
|
|
402
381
|
latents = self.sampler.step(latents, noise_pred, i)
|
|
@@ -410,58 +389,43 @@ class WanVideoPipeline(BasePipeline):
|
|
|
410
389
|
return frames
|
|
411
390
|
|
|
412
391
|
@classmethod
|
|
413
|
-
def from_pretrained(
|
|
414
|
-
cls,
|
|
415
|
-
model_path_or_config: str | WanModelConfig,
|
|
416
|
-
shift: float | None = None,
|
|
417
|
-
batch_cfg: bool = False,
|
|
418
|
-
vae_tiled: bool = True,
|
|
419
|
-
vae_tile_size: Tuple[int, int] = (34, 34),
|
|
420
|
-
vae_tile_stride: Tuple[int, int] = (18, 16),
|
|
421
|
-
device: str = "cuda",
|
|
422
|
-
dtype: torch.dtype = torch.bfloat16,
|
|
423
|
-
offload_mode: str | None = None,
|
|
424
|
-
parallelism: int = 1,
|
|
425
|
-
use_cfg_parallel: bool = False,
|
|
426
|
-
) -> "WanVideoPipeline":
|
|
392
|
+
def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPipeline":
|
|
427
393
|
if isinstance(model_path_or_config, str):
|
|
428
|
-
|
|
394
|
+
config = WanPipelineConfig(model_path=model_path_or_config)
|
|
429
395
|
else:
|
|
430
|
-
|
|
396
|
+
config = model_path_or_config
|
|
431
397
|
|
|
432
|
-
if
|
|
433
|
-
|
|
434
|
-
if
|
|
435
|
-
|
|
436
|
-
if model_config.vae_path is None:
|
|
437
|
-
model_config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
|
|
398
|
+
if config.t5_path is None:
|
|
399
|
+
config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
|
|
400
|
+
if config.vae_path is None:
|
|
401
|
+
config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
|
|
438
402
|
|
|
439
|
-
logger.info(f"loading state dict from {
|
|
440
|
-
dit_state_dict = cls.load_model_checkpoint(
|
|
403
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
404
|
+
dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
441
405
|
|
|
442
|
-
logger.info(f"loading state dict from {
|
|
443
|
-
t5_state_dict = cls.load_model_checkpoint(
|
|
406
|
+
logger.info(f"loading state dict from {config.t5_path} ...")
|
|
407
|
+
t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
|
|
444
408
|
|
|
445
|
-
logger.info(f"loading state dict from {
|
|
446
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
409
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
410
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
447
411
|
|
|
448
|
-
init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
|
|
412
|
+
init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
|
|
449
413
|
tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
|
|
450
|
-
text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=
|
|
451
|
-
vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=
|
|
414
|
+
text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
|
|
415
|
+
vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
|
|
452
416
|
|
|
453
417
|
image_encoder = None
|
|
454
|
-
if
|
|
455
|
-
logger.info(f"loading state dict from {
|
|
418
|
+
if config.image_encoder_path is not None:
|
|
419
|
+
logger.info(f"loading state dict from {config.image_encoder_path} ...")
|
|
456
420
|
image_encoder_state_dict = cls.load_model_checkpoint(
|
|
457
|
-
|
|
421
|
+
config.image_encoder_path,
|
|
458
422
|
device="cpu",
|
|
459
|
-
dtype=
|
|
423
|
+
dtype=config.image_encoder_dtype,
|
|
460
424
|
)
|
|
461
425
|
image_encoder = WanImageEncoder.from_state_dict(
|
|
462
426
|
image_encoder_state_dict,
|
|
463
427
|
device=init_device,
|
|
464
|
-
dtype=
|
|
428
|
+
dtype=config.image_encoder_dtype,
|
|
465
429
|
)
|
|
466
430
|
|
|
467
431
|
# determine wan video model type by dit params
|
|
@@ -476,50 +440,47 @@ class WanVideoPipeline(BasePipeline):
|
|
|
476
440
|
model_type = "1.3b-t2v"
|
|
477
441
|
|
|
478
442
|
# shift for different model_type
|
|
479
|
-
shift = SHIFT_FACTORS[model_type] if shift is None else shift
|
|
443
|
+
config.shift = SHIFT_FACTORS[model_type] if config.shift is None else config.shift
|
|
480
444
|
|
|
481
445
|
with LoRAContext():
|
|
446
|
+
attn_kwargs = {
|
|
447
|
+
"attn_impl": config.dit_attn_impl,
|
|
448
|
+
"sparge_smooth_k": config.sparge_smooth_k,
|
|
449
|
+
"sparge_cdfthreshd": config.sparge_cdfthreshd,
|
|
450
|
+
"sparge_simthreshd1": config.sparge_simthreshd1,
|
|
451
|
+
"sparge_pvthreshd": config.sparge_pvthreshd,
|
|
452
|
+
}
|
|
482
453
|
dit = WanDiT.from_state_dict(
|
|
483
454
|
dit_state_dict,
|
|
484
455
|
model_type=model_type,
|
|
485
456
|
device=init_device,
|
|
486
|
-
dtype=
|
|
487
|
-
|
|
457
|
+
dtype=config.model_dtype,
|
|
458
|
+
attn_kwargs=attn_kwargs,
|
|
488
459
|
)
|
|
460
|
+
if config.use_fp8_linear:
|
|
461
|
+
enable_fp8_linear(dit)
|
|
489
462
|
|
|
490
463
|
pipe = cls(
|
|
491
|
-
config=
|
|
464
|
+
config=config,
|
|
492
465
|
tokenizer=tokenizer,
|
|
493
466
|
text_encoder=text_encoder,
|
|
494
467
|
dit=dit,
|
|
495
468
|
vae=vae,
|
|
496
469
|
image_encoder=image_encoder,
|
|
497
|
-
shift=shift,
|
|
498
|
-
batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
|
|
499
|
-
vae_tiled=vae_tiled,
|
|
500
|
-
vae_tile_size=vae_tile_size,
|
|
501
|
-
vae_tile_stride=vae_tile_stride,
|
|
502
|
-
device=device,
|
|
503
|
-
dtype=dtype,
|
|
504
470
|
)
|
|
505
471
|
pipe.eval()
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
cfg_degree = parallel_config["cfg_degree"]
|
|
512
|
-
sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
|
|
513
|
-
sp_ring_degree = parallel_config["sp_ring_degree"]
|
|
514
|
-
tp_degree = parallel_config["tp_degree"]
|
|
515
|
-
use_fsdp = parallel_config["use_fsdp"]
|
|
472
|
+
|
|
473
|
+
if config.offload_mode is not None:
|
|
474
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
475
|
+
|
|
476
|
+
if config.parallelism > 1:
|
|
516
477
|
return ParallelWrapper(
|
|
517
478
|
pipe,
|
|
518
|
-
cfg_degree=cfg_degree,
|
|
519
|
-
sp_ulysses_degree=sp_ulysses_degree,
|
|
520
|
-
sp_ring_degree=sp_ring_degree,
|
|
521
|
-
tp_degree=tp_degree,
|
|
522
|
-
use_fsdp=use_fsdp,
|
|
479
|
+
cfg_degree=config.cfg_degree,
|
|
480
|
+
sp_ulysses_degree=config.sp_ulysses_degree,
|
|
481
|
+
sp_ring_degree=config.sp_ring_degree,
|
|
482
|
+
tp_degree=config.tp_degree,
|
|
483
|
+
use_fsdp=config.use_fsdp,
|
|
523
484
|
device="cuda",
|
|
524
485
|
)
|
|
525
486
|
return pipe
|