diffsynth-engine 0.3.6.dev9__py3-none-any.whl → 0.3.6.dev11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. diffsynth_engine/__init__.py +10 -8
  2. diffsynth_engine/configs/__init__.py +23 -0
  3. diffsynth_engine/configs/controlnet.py +17 -0
  4. diffsynth_engine/configs/pipeline.py +206 -0
  5. diffsynth_engine/models/basic/attention.py +43 -4
  6. diffsynth_engine/models/flux/flux_controlnet.py +8 -5
  7. diffsynth_engine/models/flux/flux_dit.py +22 -16
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +7 -7
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/sd/sd_controlnet.py +2 -4
  11. diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
  12. diffsynth_engine/models/wan/wan_dit.py +15 -15
  13. diffsynth_engine/pipelines/__init__.py +5 -8
  14. diffsynth_engine/pipelines/base.py +14 -65
  15. diffsynth_engine/pipelines/flux_image.py +85 -158
  16. diffsynth_engine/pipelines/sd_image.py +30 -64
  17. diffsynth_engine/pipelines/sdxl_image.py +39 -71
  18. diffsynth_engine/pipelines/wan_video.py +66 -105
  19. diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
  20. diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
  21. diffsynth_engine/tools/flux_reference_tool.py +21 -5
  22. diffsynth_engine/tools/flux_replace_tool.py +15 -3
  23. diffsynth_engine/utils/parallel.py +1 -1
  24. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/METADATA +1 -1
  25. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/RECORD +28 -25
  26. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/WHEEL +0 -0
  27. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/licenses/LICENSE +0 -0
  28. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  import torch
2
- from typing import Optional, Dict
2
+ from typing import Dict
3
3
  from diffsynth_engine.models.basic.unet_helper import (
4
4
  ResnetBlock,
5
5
  AttentionBlock,
@@ -180,7 +180,6 @@ class SDXLControlNetUnion(PreTrainedModel):
180
180
 
181
181
  def __init__(
182
182
  self,
183
- attn_impl: Optional[str] = None,
184
183
  device: str = "cuda:0",
185
184
  dtype: torch.dtype = torch.bfloat16,
186
185
  ):
@@ -2,7 +2,7 @@ import math
2
2
  import json
3
3
  import torch
4
4
  import torch.nn as nn
5
- from typing import Tuple, Optional
5
+ from typing import Any, Dict, Tuple, Optional
6
6
  from einops import rearrange
7
7
 
8
8
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
@@ -69,7 +69,7 @@ class SelfAttention(nn.Module):
69
69
  dim: int,
70
70
  num_heads: int,
71
71
  eps: float = 1e-6,
72
- attn_impl: Optional[str] = None,
72
+ attn_kwargs: Optional[Dict[str, Any]] = None,
73
73
  device: str = "cuda:0",
74
74
  dtype: torch.dtype = torch.bfloat16,
75
75
  ):
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
82
82
  self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
83
83
  self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
84
84
  self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
85
- self.attn_impl = attn_impl
85
+ self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
86
86
 
87
87
  def forward(self, x, freqs):
88
88
  q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
@@ -94,7 +94,7 @@ class SelfAttention(nn.Module):
94
94
  q=rope_apply(q, freqs),
95
95
  k=rope_apply(k, freqs),
96
96
  v=v,
97
- attn_impl=self.attn_impl,
97
+ **self.attn_kwargs,
98
98
  )
99
99
  x = x.flatten(2)
100
100
  return self.o(x)
@@ -107,7 +107,7 @@ class CrossAttention(nn.Module):
107
107
  num_heads: int,
108
108
  eps: float = 1e-6,
109
109
  has_image_input: bool = False,
110
- attn_impl: Optional[str] = None,
110
+ attn_kwargs: Optional[Dict[str, Any]] = None,
111
111
  device: str = "cuda:0",
112
112
  dtype: torch.dtype = torch.bfloat16,
113
113
  ):
@@ -126,7 +126,7 @@ class CrossAttention(nn.Module):
126
126
  self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
127
127
  self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
128
128
  self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
129
- self.attn_impl = attn_impl
129
+ self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
130
130
 
131
131
  def forward(self, x: torch.Tensor, y: torch.Tensor):
132
132
  if self.has_image_input:
@@ -140,12 +140,12 @@ class CrossAttention(nn.Module):
140
140
  k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
141
141
  v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
142
142
 
143
- x = attention_ops.attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
143
+ x = attention_ops.attention(q, k, v, **self.attn_kwargs).flatten(2)
144
144
  if self.has_image_input:
145
145
  k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
146
146
  k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
147
147
  v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
148
- y = attention_ops.attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
148
+ y = attention_ops.attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
149
149
  x = x + y
150
150
  return self.o(x)
151
151
 
@@ -158,7 +158,7 @@ class DiTBlock(nn.Module):
158
158
  num_heads: int,
159
159
  ffn_dim: int,
160
160
  eps: float = 1e-6,
161
- attn_impl: Optional[str] = None,
161
+ attn_kwargs: Optional[Dict[str, Any]] = None,
162
162
  device: str = "cuda:0",
163
163
  dtype: torch.dtype = torch.bfloat16,
164
164
  ):
@@ -166,9 +166,9 @@ class DiTBlock(nn.Module):
166
166
  self.dim = dim
167
167
  self.num_heads = num_heads
168
168
  self.ffn_dim = ffn_dim
169
- self.self_attn = SelfAttention(dim, num_heads, eps, attn_impl=attn_impl, device=device, dtype=dtype)
169
+ self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
170
170
  self.cross_attn = CrossAttention(
171
- dim, num_heads, eps, has_image_input=has_image_input, attn_impl=attn_impl, device=device, dtype=dtype
171
+ dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
172
172
  )
173
173
  self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
174
174
  self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -265,7 +265,7 @@ class WanDiT(PreTrainedModel):
265
265
  num_layers: int,
266
266
  has_image_input: bool,
267
267
  flf_pos_emb: bool = False,
268
- attn_impl: Optional[str] = None,
268
+ attn_kwargs: Optional[Dict[str, Any]] = None,
269
269
  device: str = "cpu",
270
270
  dtype: torch.dtype = torch.bfloat16,
271
271
  ):
@@ -296,7 +296,7 @@ class WanDiT(PreTrainedModel):
296
296
  )
297
297
  self.blocks = nn.ModuleList(
298
298
  [
299
- DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_impl, device=device, dtype=dtype)
299
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
300
300
  for _ in range(num_layers)
301
301
  ]
302
302
  )
@@ -376,7 +376,7 @@ class WanDiT(PreTrainedModel):
376
376
  device,
377
377
  dtype,
378
378
  model_type="1.3b-t2v",
379
- attn_impl: Optional[str] = None,
379
+ attn_kwargs: Optional[Dict[str, Any]] = None,
380
380
  assign=True,
381
381
  ):
382
382
  if model_type == "1.3b-t2v":
@@ -390,7 +390,7 @@ class WanDiT(PreTrainedModel):
390
390
  else:
391
391
  raise ValueError(f"Unsupported model type: {model_type}")
392
392
  with no_init_weights():
393
- model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_impl=attn_impl)
393
+ model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
394
394
  model = model.requires_grad_(False)
395
395
  model.load_state_dict(state_dict, assign=assign)
396
396
  model.to(device=device, dtype=dtype)
@@ -1,20 +1,17 @@
1
1
  from .base import BasePipeline, LoRAStateDictConverter
2
2
  from .controlnet_helper import ControlNetParams
3
- from .flux_image import FluxImagePipeline, FluxModelConfig
4
- from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
5
- from .sd_image import SDImagePipeline, SDModelConfig
6
- from .wan_video import WanVideoPipeline, WanModelConfig
3
+ from .flux_image import FluxImagePipeline
4
+ from .sdxl_image import SDXLImagePipeline
5
+ from .sd_image import SDImagePipeline
6
+ from .wan_video import WanVideoPipeline
7
+
7
8
 
8
9
  __all__ = [
9
10
  "BasePipeline",
10
11
  "LoRAStateDictConverter",
11
12
  "FluxImagePipeline",
12
- "FluxModelConfig",
13
13
  "SDXLImagePipeline",
14
- "SDXLModelConfig",
15
14
  "SDImagePipeline",
16
- "SDModelConfig",
17
15
  "WanVideoPipeline",
18
- "WanModelConfig",
19
16
  "ControlNetParams",
20
17
  ]
@@ -3,7 +3,7 @@ import torch
3
3
  import numpy as np
4
4
  from typing import Dict, List, Tuple
5
5
  from PIL import Image
6
- from dataclasses import dataclass
6
+ from diffsynth_engine.configs import BaseConfig
7
7
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
8
8
  from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
9
9
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -14,11 +14,6 @@ from diffsynth_engine.utils.platform import empty_cache
14
14
  logger = logging.get_logger(__name__)
15
15
 
16
16
 
17
- @dataclass
18
- class ModelConfig:
19
- pass
20
-
21
-
22
17
  class LoRAStateDictConverter:
23
18
  def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
24
19
  return {"lora": lora_state_dict}
@@ -30,8 +25,8 @@ class BasePipeline:
30
25
  def __init__(
31
26
  self,
32
27
  vae_tiled: bool = False,
33
- vae_tile_size: int = -1,
34
- vae_tile_stride: int = -1,
28
+ vae_tile_size: int | Tuple[int, int] = -1,
29
+ vae_tile_stride: int | Tuple[int, int] = -1,
35
30
  device="cuda",
36
31
  dtype=torch.float16,
37
32
  ):
@@ -46,13 +41,7 @@ class BasePipeline:
46
41
  self._models_offload_params = {}
47
42
 
48
43
  @classmethod
49
- def from_pretrained(
50
- cls,
51
- model_path_or_config: str | os.PathLike | ModelConfig,
52
- device: str = "cuda",
53
- dtype: torch.dtype = torch.float16,
54
- offload_mode: str | None = None,
55
- ) -> "BasePipeline":
44
+ def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
56
45
  raise NotImplementedError()
57
46
 
58
47
  @classmethod
@@ -224,54 +213,6 @@ class BasePipeline:
224
213
  model.eval()
225
214
  return self
226
215
 
227
- @staticmethod
228
- def init_parallel_config(
229
- parallelism: int,
230
- use_cfg_parallel: bool,
231
- model_config: ModelConfig,
232
- ):
233
- assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
234
- cfg_degree = 2 if use_cfg_parallel else 1
235
- sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
236
- sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
237
- tp_degree = getattr(model_config, "tp_degree", None)
238
- use_fsdp = getattr(model_config, "use_fsdp", False)
239
-
240
- if tp_degree is not None:
241
- assert sp_ulysses_degree is None and sp_ring_degree is None, (
242
- "not allowed to enable sequence parallel and tensor parallel together; "
243
- "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
244
- )
245
- assert use_fsdp is False, (
246
- "not allowed to enable fully sharded data parallel and tensor parallel together; "
247
- "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
248
- )
249
- assert parallelism == cfg_degree * tp_degree, (
250
- f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
251
- )
252
- sp_ulysses_degree = 1
253
- sp_ring_degree = 1
254
- elif sp_ulysses_degree is None and sp_ring_degree is None:
255
- # use ulysses if not specified
256
- sp_ulysses_degree = parallelism // cfg_degree
257
- sp_ring_degree = 1
258
- tp_degree = 1
259
- elif sp_ulysses_degree is not None and sp_ring_degree is not None:
260
- assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
261
- f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
262
- f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
263
- )
264
- tp_degree = 1
265
- else:
266
- raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
267
- return {
268
- "cfg_degree": cfg_degree,
269
- "sp_ulysses_degree": sp_ulysses_degree,
270
- "sp_ring_degree": sp_ring_degree,
271
- "tp_degree": tp_degree,
272
- "use_fsdp": use_fsdp,
273
- }
274
-
275
216
  def enable_cpu_offload(self, offload_mode: str):
276
217
  valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
277
218
  if offload_mode not in valid_offload_mode:
@@ -326,14 +267,22 @@ class BasePipeline:
326
267
  for model_name in self.model_names:
327
268
  if model_name not in load_model_names:
328
269
  model = getattr(self, model_name)
329
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device("cpu"):
270
+ if (
271
+ model is not None
272
+ and (p := next(model.parameters(), None)) is not None
273
+ and p.device != torch.device("cpu")
274
+ ):
330
275
  param_cache = self._models_offload_params[model_name]
331
276
  for name, param in model.named_parameters(recurse=True):
332
277
  param.data = param_cache[name]
333
278
  # load the needed models to device
334
279
  for model_name in load_model_names:
335
280
  model = getattr(self, model_name)
336
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device(self.device):
281
+ if (
282
+ model is not None
283
+ and (p := next(model.parameters(), None)) is not None
284
+ and p.device != torch.device(self.device)
285
+ ):
337
286
  model.to(self.device)
338
287
  # fresh the cuda cache
339
288
  empty_cache()