diffsynth-engine 0.3.6.dev8__py3-none-any.whl → 0.3.6.dev10__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (30)
  1. diffsynth_engine/__init__.py +10 -8
  2. diffsynth_engine/configs/__init__.py +23 -0
  3. diffsynth_engine/configs/controlnet.py +17 -0
  4. diffsynth_engine/configs/pipeline.py +206 -0
  5. diffsynth_engine/models/basic/attention.py +43 -4
  6. diffsynth_engine/models/flux/flux_controlnet.py +8 -5
  7. diffsynth_engine/models/flux/flux_dit.py +22 -16
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/sd/sd_controlnet.py +2 -4
  11. diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
  12. diffsynth_engine/models/wan/wan_dit.py +15 -15
  13. diffsynth_engine/pipelines/__init__.py +5 -8
  14. diffsynth_engine/pipelines/base.py +14 -65
  15. diffsynth_engine/pipelines/flux_image.py +85 -158
  16. diffsynth_engine/pipelines/sd_image.py +30 -64
  17. diffsynth_engine/pipelines/sdxl_image.py +39 -71
  18. diffsynth_engine/pipelines/wan_video.py +66 -105
  19. diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
  20. diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
  21. diffsynth_engine/tools/flux_reference_tool.py +21 -5
  22. diffsynth_engine/tools/flux_replace_tool.py +15 -3
  23. diffsynth_engine/utils/fp8_linear.py +14 -5
  24. diffsynth_engine/utils/parallel.py +1 -1
  25. diffsynth_engine/utils/platform.py +9 -1
  26. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/METADATA +1 -1
  27. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/RECORD +30 -27
  28. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/WHEEL +0 -0
  29. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/licenses/LICENSE +0 -0
  30. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,12 @@
1
1
  import re
2
- import os
3
2
  import torch
4
3
  import numpy as np
5
4
  from einops import repeat
6
- from dataclasses import dataclass
7
5
  from typing import Callable, Dict, Optional, List
8
6
  from tqdm import tqdm
9
7
  from PIL import Image, ImageOps
10
8
 
9
+ from diffsynth_engine.configs import SDPipelineConfig
11
10
  from diffsynth_engine.models.base import split_suffix
12
11
  from diffsynth_engine.models.basic.lora import LoRAContext
13
12
  from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
@@ -84,17 +83,6 @@ def convert_diffusers_name_to_compvis(key):
84
83
  return key
85
84
 
86
85
 
87
- @dataclass
88
- class SDModelConfig:
89
- unet_path: str | os.PathLike
90
- clip_path: Optional[str | os.PathLike] = None
91
- vae_path: Optional[str | os.PathLike] = None
92
-
93
- unet_dtype: torch.dtype = torch.float16
94
- clip_dtype: torch.dtype = torch.float16
95
- vae_dtype: torch.dtype = torch.float32
96
-
97
-
98
86
  class SDLoRAConverter(LoRAStateDictConverter):
99
87
  def _replace_kohya_te_key(self, key):
100
88
  key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
@@ -151,27 +139,22 @@ class SDImagePipeline(BasePipeline):
151
139
 
152
140
  def __init__(
153
141
  self,
154
- config: SDModelConfig,
142
+ config: SDPipelineConfig,
155
143
  tokenizer: CLIPTokenizer,
156
144
  text_encoder: SDTextEncoder,
157
145
  unet: SDUNet,
158
146
  vae_decoder: SDVAEDecoder,
159
147
  vae_encoder: SDVAEEncoder,
160
- batch_cfg: bool = True,
161
- vae_tiled: bool = False,
162
- vae_tile_size: int = 256,
163
- vae_tile_stride: int = 256,
164
- device: str = "cuda",
165
- dtype: torch.dtype = torch.float16,
166
148
  ):
167
149
  super().__init__(
168
- vae_tiled=vae_tiled,
169
- vae_tile_size=vae_tile_size,
170
- vae_tile_stride=vae_tile_stride,
171
- device=device,
172
- dtype=dtype,
150
+ vae_tiled=config.vae_tiled,
151
+ vae_tile_size=config.vae_tile_size,
152
+ vae_tile_stride=config.vae_tile_stride,
153
+ device=config.device,
154
+ dtype=config.model_dtype,
173
155
  )
174
156
  self.config = config
157
+ # sampler
175
158
  self.noise_scheduler = ScaledLinearScheduler()
176
159
  self.sampler = EulerSampler()
177
160
  # models
@@ -180,71 +163,54 @@ class SDImagePipeline(BasePipeline):
180
163
  self.unet = unet
181
164
  self.vae_decoder = vae_decoder
182
165
  self.vae_encoder = vae_encoder
183
- self.batch_cfg = batch_cfg
184
166
  self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
185
167
 
186
168
  @classmethod
187
- def from_pretrained(
188
- cls,
189
- model_path_or_config: str | os.PathLike | SDModelConfig,
190
- batch_cfg: bool = True,
191
- vae_tiled: bool = False,
192
- vae_tile_size: int = 256,
193
- vae_tile_stride: int = 256,
194
- device: str = "cuda",
195
- dtype: torch.dtype = torch.float16,
196
- offload_mode: str | None = None,
197
- ) -> "SDImagePipeline":
169
+ def from_pretrained(cls, model_path_or_config: SDPipelineConfig) -> "SDImagePipeline":
198
170
  if isinstance(model_path_or_config, str):
199
- model_config = SDModelConfig(unet_path=model_path_or_config)
171
+ config = SDPipelineConfig(model_path=model_path_or_config)
200
172
  else:
201
- model_config = model_path_or_config
173
+ config = model_path_or_config
202
174
 
203
- logger.info(f"loading state dict from {model_config.unet_path} ...")
204
- unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)
175
+ logger.info(f"loading state dict from {config.model_path} ...")
176
+ unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
205
177
 
206
- if model_config.vae_path is not None:
207
- logger.info(f"loading state dict from {model_config.vae_path} ...")
208
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
178
+ if config.vae_path is not None:
179
+ logger.info(f"loading state dict from {config.vae_path} ...")
180
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
209
181
  else:
210
182
  vae_state_dict = unet_state_dict
211
183
 
212
- if model_config.clip_path is not None:
213
- logger.info(f"loading state dict from {model_config.clip_path} ...")
214
- clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
184
+ if config.clip_path is not None:
185
+ logger.info(f"loading state dict from {config.clip_path} ...")
186
+ clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
215
187
  else:
216
188
  clip_state_dict = unet_state_dict
217
189
 
218
- init_device = "cpu" if offload_mode else device
190
+ init_device = "cpu" if config.offload_mode is not None else config.device
219
191
  tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
220
192
  with LoRAContext():
221
- text_encoder = SDTextEncoder.from_state_dict(
222
- clip_state_dict, device=init_device, dtype=model_config.clip_dtype
223
- )
224
- unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
193
+ text_encoder = SDTextEncoder.from_state_dict(clip_state_dict, device=init_device, dtype=config.clip_dtype)
194
+ unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
225
195
  vae_decoder = SDVAEDecoder.from_state_dict(
226
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
196
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
227
197
  )
228
198
  vae_encoder = SDVAEEncoder.from_state_dict(
229
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
199
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
230
200
  )
231
201
 
232
202
  pipe = cls(
233
- config=model_config,
203
+ config=config,
234
204
  tokenizer=tokenizer,
235
205
  text_encoder=text_encoder,
236
206
  unet=unet,
237
207
  vae_decoder=vae_decoder,
238
208
  vae_encoder=vae_encoder,
239
- batch_cfg=batch_cfg,
240
- vae_tiled=vae_tiled,
241
- vae_tile_size=vae_tile_size,
242
- vae_tile_stride=vae_tile_stride,
243
- device=device,
244
- dtype=dtype,
245
209
  )
246
- if offload_mode is not None:
247
- pipe.enable_cpu_offload(offload_mode)
210
+ pipe.eval()
211
+
212
+ if config.offload_mode is not None:
213
+ pipe.enable_cpu_offload(config.offload_mode)
248
214
  return pipe
249
215
 
250
216
  @classmethod
@@ -439,7 +405,7 @@ class SDImagePipeline(BasePipeline):
439
405
  controlnet_params=controlnet_params,
440
406
  current_step=i,
441
407
  total_step=len(timesteps),
442
- batch_cfg=self.batch_cfg,
408
+ batch_cfg=self.config.batch_cfg,
443
409
  )
444
410
  # Denoise
445
411
  latents = self.sampler.step(latents, noise_pred, i)
@@ -1,4 +1,3 @@
1
- import os
2
1
  import re
3
2
  import torch
4
3
  import numpy as np
@@ -6,8 +5,8 @@ from einops import repeat
6
5
  from typing import Callable, Dict, Optional, List
7
6
  from tqdm import tqdm
8
7
  from PIL import Image, ImageOps
9
- from dataclasses import dataclass
10
8
 
9
+ from diffsynth_engine.configs import SDXLPipelineConfig
11
10
  from diffsynth_engine.models.base import split_suffix
12
11
  from diffsynth_engine.models.basic.lora import LoRAContext
13
12
  from diffsynth_engine.models.basic.timestep import TemporalTimesteps
@@ -102,25 +101,12 @@ class SDXLLoRAConverter(LoRAStateDictConverter):
102
101
  raise ValueError(f"Unsupported key: {key}")
103
102
 
104
103
 
105
- @dataclass
106
- class SDXLModelConfig:
107
- unet_path: str | os.PathLike
108
- clip_l_path: Optional[str | os.PathLike] = None
109
- clip_g_path: Optional[str | os.PathLike] = None
110
- vae_path: Optional[str | os.PathLike] = None
111
-
112
- unet_dtype: torch.dtype = torch.float16
113
- clip_l_dtype: torch.dtype = torch.float16
114
- clip_g_dtype: torch.dtype = torch.float16
115
- vae_dtype: torch.dtype = torch.float32
116
-
117
-
118
104
  class SDXLImagePipeline(BasePipeline):
119
105
  lora_converter = SDXLLoRAConverter()
120
106
 
121
107
  def __init__(
122
108
  self,
123
- config: SDXLModelConfig,
109
+ config: SDXLPipelineConfig,
124
110
  tokenizer: CLIPTokenizer,
125
111
  tokenizer_2: CLIPTokenizer,
126
112
  text_encoder: SDXLTextEncoder,
@@ -128,21 +114,16 @@ class SDXLImagePipeline(BasePipeline):
128
114
  unet: SDXLUNet,
129
115
  vae_decoder: SDXLVAEDecoder,
130
116
  vae_encoder: SDXLVAEEncoder,
131
- batch_cfg: bool = True,
132
- vae_tiled: bool = False,
133
- vae_tile_size: int = 256,
134
- vae_tile_stride: int = 256,
135
- device: str = "cuda",
136
- dtype: torch.dtype = torch.float16,
137
117
  ):
138
118
  super().__init__(
139
- vae_tiled=vae_tiled,
140
- vae_tile_size=vae_tile_size,
141
- vae_tile_stride=vae_tile_stride,
142
- device=device,
143
- dtype=dtype,
119
+ vae_tiled=config.vae_tiled,
120
+ vae_tile_size=config.vae_tile_size,
121
+ vae_tile_stride=config.vae_tile_stride,
122
+ device=config.device,
123
+ dtype=config.model_dtype,
144
124
  )
145
125
  self.config = config
126
+ # sampler
146
127
  self.noise_scheduler = ScaledLinearScheduler()
147
128
  self.sampler = EulerSampler()
148
129
  # models
@@ -154,71 +135,62 @@ class SDXLImagePipeline(BasePipeline):
154
135
  self.vae_decoder = vae_decoder
155
136
  self.vae_encoder = vae_encoder
156
137
  self.add_time_proj = TemporalTimesteps(
157
- num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
138
+ num_channels=256,
139
+ flip_sin_to_cos=True,
140
+ downscale_freq_shift=0,
141
+ device=config.device,
142
+ dtype=config.model_dtype,
158
143
  )
159
- self.batch_cfg = batch_cfg
160
144
  self.model_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
161
145
 
162
146
  @classmethod
163
- def from_pretrained(
164
- cls,
165
- model_path_or_config: str | os.PathLike | SDXLModelConfig,
166
- batch_cfg: bool = True,
167
- vae_tiled: bool = False,
168
- vae_tile_size: int = 256,
169
- vae_tile_stride: int = 256,
170
- device: str = "cuda",
171
- dtype: torch.dtype = torch.float16,
172
- offload_mode: str | None = None,
173
- ) -> "SDXLImagePipeline":
147
+ def from_pretrained(cls, model_path_or_config: SDXLPipelineConfig) -> "SDXLImagePipeline":
174
148
  if isinstance(model_path_or_config, str):
175
- model_config = SDXLModelConfig(
176
- unet_path=model_path_or_config, unet_dtype=dtype, clip_l_dtype=dtype, clip_g_dtype=dtype
177
- )
149
+ config = SDXLPipelineConfig(model_path=model_path_or_config)
178
150
  else:
179
- model_config = model_path_or_config
151
+ config = model_path_or_config
180
152
 
181
- logger.info(f"loading state dict from {model_config.unet_path} ...")
182
- unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)
153
+ logger.info(f"loading state dict from {config.model_path} ...")
154
+ unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
183
155
 
184
- if model_config.vae_path is not None:
185
- logger.info(f"loading state dict from {model_config.vae_path} ...")
186
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
156
+ if config.vae_path is not None:
157
+ logger.info(f"loading state dict from {config.vae_path} ...")
158
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
187
159
  else:
188
160
  vae_state_dict = unet_state_dict
189
161
 
190
- if model_config.clip_l_path is not None:
191
- logger.info(f"loading state dict from {model_config.clip_l_path} ...")
192
- clip_l_state_dict = cls.load_model_checkpoint(model_config.clip_l_path, device="cpu", dtype=dtype)
162
+ if config.clip_l_path is not None:
163
+ logger.info(f"loading state dict from {config.clip_l_path} ...")
164
+ clip_l_state_dict = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype)
193
165
  else:
194
166
  clip_l_state_dict = unet_state_dict
195
167
 
196
- if model_config.clip_g_path is not None:
197
- logger.info(f"loading state dict from {model_config.clip_g_path} ...")
198
- clip_g_state_dict = cls.load_model_checkpoint(model_config.clip_g_path, device="cpu", dtype=dtype)
168
+ if config.clip_g_path is not None:
169
+ logger.info(f"loading state dict from {config.clip_g_path} ...")
170
+ clip_g_state_dict = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype)
199
171
  else:
200
172
  clip_g_state_dict = unet_state_dict
201
173
 
202
- init_device = "cpu" if offload_mode else device
174
+ init_device = "cpu" if config.offload_mode else config.device
203
175
  tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
204
176
  tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
205
177
  with LoRAContext():
206
178
  text_encoder = SDXLTextEncoder.from_state_dict(
207
- clip_l_state_dict, device=init_device, dtype=model_config.clip_l_dtype
179
+ clip_l_state_dict, device=init_device, dtype=config.clip_l_dtype
208
180
  )
209
181
  text_encoder_2 = SDXLTextEncoder2.from_state_dict(
210
- clip_g_state_dict, device=init_device, dtype=model_config.clip_g_dtype
182
+ clip_g_state_dict, device=init_device, dtype=config.clip_g_dtype
211
183
  )
212
- unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
184
+ unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
213
185
  vae_decoder = SDXLVAEDecoder.from_state_dict(
214
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
186
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
215
187
  )
216
188
  vae_encoder = SDXLVAEEncoder.from_state_dict(
217
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
189
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
218
190
  )
219
191
 
220
192
  pipe = cls(
221
- config=model_config,
193
+ config=config,
222
194
  tokenizer=tokenizer,
223
195
  tokenizer_2=tokenizer_2,
224
196
  text_encoder=text_encoder,
@@ -226,15 +198,11 @@ class SDXLImagePipeline(BasePipeline):
226
198
  unet=unet,
227
199
  vae_decoder=vae_decoder,
228
200
  vae_encoder=vae_encoder,
229
- batch_cfg=batch_cfg,
230
- vae_tiled=vae_tiled,
231
- vae_tile_size=vae_tile_size,
232
- vae_tile_stride=vae_tile_stride,
233
- device=device,
234
- dtype=dtype,
235
201
  )
236
- if offload_mode is not None:
237
- pipe.enable_cpu_offload(offload_mode)
202
+ pipe.eval()
203
+
204
+ if config.offload_mode is not None:
205
+ pipe.enable_cpu_offload(config.offload_mode)
238
206
  return pipe
239
207
 
240
208
  @classmethod
@@ -517,7 +485,7 @@ class SDXLImagePipeline(BasePipeline):
517
485
  controlnet_params=controlnet_params,
518
486
  current_step=i,
519
487
  total_step=len(timesteps),
520
- batch_cfg=self.batch_cfg,
488
+ batch_cfg=self.config.batch_cfg,
521
489
  )
522
490
  # Denoise
523
491
  latents = self.sampler.step(latents, noise_pred, i)
@@ -2,11 +2,11 @@ import torch
2
2
  import torch.distributed as dist
3
3
  import numpy as np
4
4
  from einops import rearrange
5
- from dataclasses import dataclass
6
5
  from typing import Callable, List, Tuple, Optional
7
6
  from tqdm import tqdm
8
7
  from PIL import Image
9
8
 
9
+ from diffsynth_engine.configs import WanPipelineConfig
10
10
  from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
11
11
  from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
12
12
  from diffsynth_engine.models.wan.wan_dit import WanDiT
@@ -18,6 +18,7 @@ from diffsynth_engine.tokenizers import WanT5Tokenizer
18
18
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
19
19
  from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
20
20
  from diffsynth_engine.utils.download import fetch_model
21
+ from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
21
22
  from diffsynth_engine.utils.parallel import ParallelWrapper
22
23
  from diffsynth_engine.utils import logging
23
24
 
@@ -25,26 +26,6 @@ from diffsynth_engine.utils import logging
25
26
  logger = logging.get_logger(__name__)
26
27
 
27
28
 
28
- @dataclass
29
- class WanModelConfig:
30
- model_path: Optional[str] = None
31
- vae_path: Optional[str] = None
32
- t5_path: Optional[str] = None
33
- image_encoder_path: Optional[str] = None
34
-
35
- vae_dtype: torch.dtype = torch.float32
36
- dit_dtype: torch.dtype = torch.bfloat16
37
- t5_dtype: torch.dtype = torch.bfloat16
38
- image_encoder_dtype: torch.dtype = torch.bfloat16
39
-
40
- dit_attn_impl: Optional[str] = "auto"
41
-
42
- sp_ulysses_degree: Optional[int] = None
43
- sp_ring_degree: Optional[int] = None
44
- tp_degree: Optional[int] = None
45
- use_fsdp: bool = False
46
-
47
-
48
29
  class WanLoRAConverter(LoRAStateDictConverter):
49
30
  def _from_diffsynth(self, state_dict):
50
31
  dit_dict = {}
@@ -129,42 +110,40 @@ class WanVideoPipeline(BasePipeline):
129
110
 
130
111
  def __init__(
131
112
  self,
132
- config: WanModelConfig,
113
+ config: WanPipelineConfig,
133
114
  tokenizer: WanT5Tokenizer,
134
115
  text_encoder: WanTextEncoder,
135
116
  dit: WanDiT,
136
117
  vae: WanVideoVAE,
137
118
  image_encoder: WanImageEncoder,
138
- shift: float = 5.0,
139
- batch_cfg: bool = False,
140
- vae_tiled: bool = True,
141
- vae_tile_size: Tuple[int, int] = (34, 34),
142
- vae_tile_stride: Tuple[int, int] = (18, 16),
143
- device="cuda",
144
- dtype=torch.bfloat16,
145
119
  ):
146
120
  super().__init__(
147
- vae_tiled=vae_tiled,
148
- vae_tile_size=vae_tile_size,
149
- vae_tile_stride=vae_tile_stride,
150
- device=device,
151
- dtype=dtype,
121
+ vae_tiled=config.vae_tiled,
122
+ vae_tile_size=config.vae_tile_size,
123
+ vae_tile_stride=config.vae_tile_stride,
124
+ device=config.device,
125
+ dtype=config.model_dtype,
152
126
  )
153
127
  self.config = config
154
- self.noise_scheduler = RecifitedFlowScheduler(shift=shift, sigma_min=0.001, sigma_max=0.999)
128
+ # sampler
129
+ self.noise_scheduler = RecifitedFlowScheduler(
130
+ shift=config.shift if config.shift is not None else 5.0,
131
+ sigma_min=0.001,
132
+ sigma_max=0.999,
133
+ )
155
134
  self.sampler = FlowMatchEulerSampler()
135
+ # models
156
136
  self.tokenizer = tokenizer
157
137
  self.text_encoder = text_encoder
158
138
  self.dit = dit
159
139
  self.vae = vae
160
140
  self.image_encoder = image_encoder
161
- self.batch_cfg = batch_cfg
162
141
  self.model_names = ["text_encoder", "dit", "vae", "image_encoder"]
163
142
 
164
143
  def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
165
- assert self.config.tp_degree is None, (
144
+ assert self.config.tp_degree is None or self.config.tp_degree == 1, (
166
145
  "load LoRA is not allowed when tensor parallel is enabled; "
167
- "set tp_degree=None during pipeline initialization"
146
+ "set tp_degree=None or tp_degree=1 during pipeline initialization"
168
147
  )
169
148
  assert not (self.config.use_fsdp and fused), (
170
149
  "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
@@ -228,7 +207,7 @@ class WanVideoPipeline(BasePipeline):
228
207
  tile_size=self.vae_tile_size,
229
208
  tile_stride=self.vae_tile_stride,
230
209
  )
231
- latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
210
+ latents = latents.to(dtype=self.config.model_dtype, device=self.device)
232
211
  return latents
233
212
 
234
213
  def decode_video(self, latents, progress_callback=None) -> List[torch.Tensor]:
@@ -241,7 +220,7 @@ class WanVideoPipeline(BasePipeline):
241
220
  tile_stride=self.vae_tile_stride,
242
221
  progress_callback=progress_callback,
243
222
  )
244
- videos = [video.to(dtype=self.config.dit_dtype, device=self.device) for video in videos]
223
+ videos = [video.to(dtype=self.config.model_dtype, device=self.device) for video in videos]
245
224
  return videos
246
225
 
247
226
  def predict_noise_with_cfg(
@@ -301,7 +280,7 @@ class WanVideoPipeline(BasePipeline):
301
280
  return noise_pred
302
281
 
303
282
  def predict_noise(self, latents, image_clip_feature, image_y, timestep, context):
304
- latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
283
+ latents = latents.to(dtype=self.config.model_dtype, device=self.device)
305
284
 
306
285
  noise_pred = self.dit(
307
286
  x=latents,
@@ -386,7 +365,7 @@ class WanVideoPipeline(BasePipeline):
386
365
  self.load_models_to_device(["dit"])
387
366
  hide_progress = dist.is_initialized() and dist.get_rank() != 0
388
367
  for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
389
- timestep = timestep.unsqueeze(0).to(dtype=self.config.dit_dtype, device=self.device)
368
+ timestep = timestep.unsqueeze(0).to(dtype=self.config.model_dtype, device=self.device)
390
369
  # Classifier-free guidance
391
370
  noise_pred = self.predict_noise_with_cfg(
392
371
  latents=latents,
@@ -396,7 +375,7 @@ class WanVideoPipeline(BasePipeline):
396
375
  image_clip_feature=image_clip_feature,
397
376
  image_y=image_y,
398
377
  cfg_scale=cfg_scale,
399
- batch_cfg=self.batch_cfg,
378
+ batch_cfg=self.config.batch_cfg,
400
379
  )
401
380
  # Scheduler
402
381
  latents = self.sampler.step(latents, noise_pred, i)
@@ -410,58 +389,43 @@ class WanVideoPipeline(BasePipeline):
410
389
  return frames
411
390
 
412
391
  @classmethod
413
- def from_pretrained(
414
- cls,
415
- model_path_or_config: str | WanModelConfig,
416
- shift: float | None = None,
417
- batch_cfg: bool = False,
418
- vae_tiled: bool = True,
419
- vae_tile_size: Tuple[int, int] = (34, 34),
420
- vae_tile_stride: Tuple[int, int] = (18, 16),
421
- device: str = "cuda",
422
- dtype: torch.dtype = torch.bfloat16,
423
- offload_mode: str | None = None,
424
- parallelism: int = 1,
425
- use_cfg_parallel: bool = False,
426
- ) -> "WanVideoPipeline":
392
+ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPipeline":
427
393
  if isinstance(model_path_or_config, str):
428
- model_config = WanModelConfig(model_path=model_path_or_config)
394
+ config = WanPipelineConfig(model_path=model_path_or_config)
429
395
  else:
430
- model_config = model_path_or_config
396
+ config = model_path_or_config
431
397
 
432
- if model_config.model_path is None:
433
- model_config.model_path = fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors")
434
- if model_config.t5_path is None:
435
- model_config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
436
- if model_config.vae_path is None:
437
- model_config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
398
+ if config.t5_path is None:
399
+ config.t5_path = fetch_model("muse/wan2.1-umt5", path="umt5.safetensors")
400
+ if config.vae_path is None:
401
+ config.vae_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
438
402
 
439
- logger.info(f"loading state dict from {model_config.model_path} ...")
440
- dit_state_dict = cls.load_model_checkpoint(model_config.model_path, device="cpu", dtype=model_config.dit_dtype)
403
+ logger.info(f"loading state dict from {config.model_path} ...")
404
+ dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
441
405
 
442
- logger.info(f"loading state dict from {model_config.t5_path} ...")
443
- t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)
406
+ logger.info(f"loading state dict from {config.t5_path} ...")
407
+ t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
444
408
 
445
- logger.info(f"loading state dict from {model_config.vae_path} ...")
446
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)
409
+ logger.info(f"loading state dict from {config.vae_path} ...")
410
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
447
411
 
448
- init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
412
+ init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
449
413
  tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")
450
- text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=model_config.t5_dtype)
451
- vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
414
+ text_encoder = WanTextEncoder.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
415
+ vae = WanVideoVAE.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
452
416
 
453
417
  image_encoder = None
454
- if model_config.image_encoder_path is not None:
455
- logger.info(f"loading state dict from {model_config.image_encoder_path} ...")
418
+ if config.image_encoder_path is not None:
419
+ logger.info(f"loading state dict from {config.image_encoder_path} ...")
456
420
  image_encoder_state_dict = cls.load_model_checkpoint(
457
- model_config.image_encoder_path,
421
+ config.image_encoder_path,
458
422
  device="cpu",
459
- dtype=model_config.image_encoder_dtype,
423
+ dtype=config.image_encoder_dtype,
460
424
  )
461
425
  image_encoder = WanImageEncoder.from_state_dict(
462
426
  image_encoder_state_dict,
463
427
  device=init_device,
464
- dtype=model_config.image_encoder_dtype,
428
+ dtype=config.image_encoder_dtype,
465
429
  )
466
430
 
467
431
  # determine wan video model type by dit params
@@ -476,50 +440,47 @@ class WanVideoPipeline(BasePipeline):
476
440
  model_type = "1.3b-t2v"
477
441
 
478
442
  # shift for different model_type
479
- shift = SHIFT_FACTORS[model_type] if shift is None else shift
443
+ config.shift = SHIFT_FACTORS[model_type] if config.shift is None else config.shift
480
444
 
481
445
  with LoRAContext():
446
+ attn_kwargs = {
447
+ "attn_impl": config.dit_attn_impl,
448
+ "sparge_smooth_k": config.sparge_smooth_k,
449
+ "sparge_cdfthreshd": config.sparge_cdfthreshd,
450
+ "sparge_simthreshd1": config.sparge_simthreshd1,
451
+ "sparge_pvthreshd": config.sparge_pvthreshd,
452
+ }
482
453
  dit = WanDiT.from_state_dict(
483
454
  dit_state_dict,
484
455
  model_type=model_type,
485
456
  device=init_device,
486
- dtype=model_config.dit_dtype,
487
- attn_impl=model_config.dit_attn_impl,
457
+ dtype=config.model_dtype,
458
+ attn_kwargs=attn_kwargs,
488
459
  )
460
+ if config.use_fp8_linear:
461
+ enable_fp8_linear(dit)
489
462
 
490
463
  pipe = cls(
491
- config=model_config,
464
+ config=config,
492
465
  tokenizer=tokenizer,
493
466
  text_encoder=text_encoder,
494
467
  dit=dit,
495
468
  vae=vae,
496
469
  image_encoder=image_encoder,
497
- shift=shift,
498
- batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
499
- vae_tiled=vae_tiled,
500
- vae_tile_size=vae_tile_size,
501
- vae_tile_stride=vae_tile_stride,
502
- device=device,
503
- dtype=dtype,
504
470
  )
505
471
  pipe.eval()
506
- if offload_mode is not None:
507
- pipe.enable_cpu_offload(offload_mode)
508
-
509
- if parallelism > 1:
510
- parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
511
- cfg_degree = parallel_config["cfg_degree"]
512
- sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
513
- sp_ring_degree = parallel_config["sp_ring_degree"]
514
- tp_degree = parallel_config["tp_degree"]
515
- use_fsdp = parallel_config["use_fsdp"]
472
+
473
+ if config.offload_mode is not None:
474
+ pipe.enable_cpu_offload(config.offload_mode)
475
+
476
+ if config.parallelism > 1:
516
477
  return ParallelWrapper(
517
478
  pipe,
518
- cfg_degree=cfg_degree,
519
- sp_ulysses_degree=sp_ulysses_degree,
520
- sp_ring_degree=sp_ring_degree,
521
- tp_degree=tp_degree,
522
- use_fsdp=use_fsdp,
479
+ cfg_degree=config.cfg_degree,
480
+ sp_ulysses_degree=config.sp_ulysses_degree,
481
+ sp_ring_degree=config.sp_ring_degree,
482
+ tp_degree=config.tp_degree,
483
+ use_fsdp=config.use_fsdp,
523
484
  device="cuda",
524
485
  )
525
486
  return pipe