diffsynth-engine 0.3.6.dev9__py3-none-any.whl → 0.3.6.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +10 -8
- diffsynth_engine/configs/__init__.py +23 -0
- diffsynth_engine/configs/controlnet.py +17 -0
- diffsynth_engine/configs/pipeline.py +206 -0
- diffsynth_engine/models/basic/attention.py +43 -4
- diffsynth_engine/models/flux/flux_controlnet.py +8 -5
- diffsynth_engine/models/flux/flux_dit.py +22 -16
- diffsynth_engine/models/flux/flux_dit_fbcache.py +7 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/sd/sd_controlnet.py +2 -4
- diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
- diffsynth_engine/models/wan/wan_dit.py +15 -15
- diffsynth_engine/pipelines/__init__.py +5 -8
- diffsynth_engine/pipelines/base.py +14 -65
- diffsynth_engine/pipelines/flux_image.py +85 -158
- diffsynth_engine/pipelines/sd_image.py +30 -64
- diffsynth_engine/pipelines/sdxl_image.py +39 -71
- diffsynth_engine/pipelines/wan_video.py +66 -105
- diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_reference_tool.py +21 -5
- diffsynth_engine/tools/flux_replace_tool.py +15 -3
- diffsynth_engine/utils/parallel.py +1 -1
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/RECORD +28 -25
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import Dict
|
|
3
3
|
from diffsynth_engine.models.basic.unet_helper import (
|
|
4
4
|
ResnetBlock,
|
|
5
5
|
AttentionBlock,
|
|
@@ -180,7 +180,6 @@ class SDXLControlNetUnion(PreTrainedModel):
|
|
|
180
180
|
|
|
181
181
|
def __init__(
|
|
182
182
|
self,
|
|
183
|
-
attn_impl: Optional[str] = None,
|
|
184
183
|
device: str = "cuda:0",
|
|
185
184
|
dtype: torch.dtype = torch.bfloat16,
|
|
186
185
|
):
|
|
@@ -2,7 +2,7 @@ import math
|
|
|
2
2
|
import json
|
|
3
3
|
import torch
|
|
4
4
|
import torch.nn as nn
|
|
5
|
-
from typing import Tuple, Optional
|
|
5
|
+
from typing import Any, Dict, Tuple, Optional
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
|
|
8
8
|
from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
|
|
@@ -69,7 +69,7 @@ class SelfAttention(nn.Module):
|
|
|
69
69
|
dim: int,
|
|
70
70
|
num_heads: int,
|
|
71
71
|
eps: float = 1e-6,
|
|
72
|
-
|
|
72
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
73
73
|
device: str = "cuda:0",
|
|
74
74
|
dtype: torch.dtype = torch.bfloat16,
|
|
75
75
|
):
|
|
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
|
|
|
82
82
|
self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
83
83
|
self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
84
84
|
self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
85
|
-
self.
|
|
85
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
86
86
|
|
|
87
87
|
def forward(self, x, freqs):
|
|
88
88
|
q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
|
|
@@ -94,7 +94,7 @@ class SelfAttention(nn.Module):
|
|
|
94
94
|
q=rope_apply(q, freqs),
|
|
95
95
|
k=rope_apply(k, freqs),
|
|
96
96
|
v=v,
|
|
97
|
-
|
|
97
|
+
**self.attn_kwargs,
|
|
98
98
|
)
|
|
99
99
|
x = x.flatten(2)
|
|
100
100
|
return self.o(x)
|
|
@@ -107,7 +107,7 @@ class CrossAttention(nn.Module):
|
|
|
107
107
|
num_heads: int,
|
|
108
108
|
eps: float = 1e-6,
|
|
109
109
|
has_image_input: bool = False,
|
|
110
|
-
|
|
110
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
111
111
|
device: str = "cuda:0",
|
|
112
112
|
dtype: torch.dtype = torch.bfloat16,
|
|
113
113
|
):
|
|
@@ -126,7 +126,7 @@ class CrossAttention(nn.Module):
|
|
|
126
126
|
self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
127
127
|
self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
128
128
|
self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
129
|
-
self.
|
|
129
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
130
130
|
|
|
131
131
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
|
132
132
|
if self.has_image_input:
|
|
@@ -140,12 +140,12 @@ class CrossAttention(nn.Module):
|
|
|
140
140
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
141
141
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
142
142
|
|
|
143
|
-
x = attention_ops.attention(q, k, v,
|
|
143
|
+
x = attention_ops.attention(q, k, v, **self.attn_kwargs).flatten(2)
|
|
144
144
|
if self.has_image_input:
|
|
145
145
|
k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
|
|
146
146
|
k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
|
|
147
147
|
v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
|
|
148
|
-
y = attention_ops.attention(q, k_img, v_img,
|
|
148
|
+
y = attention_ops.attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
|
|
149
149
|
x = x + y
|
|
150
150
|
return self.o(x)
|
|
151
151
|
|
|
@@ -158,7 +158,7 @@ class DiTBlock(nn.Module):
|
|
|
158
158
|
num_heads: int,
|
|
159
159
|
ffn_dim: int,
|
|
160
160
|
eps: float = 1e-6,
|
|
161
|
-
|
|
161
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
162
162
|
device: str = "cuda:0",
|
|
163
163
|
dtype: torch.dtype = torch.bfloat16,
|
|
164
164
|
):
|
|
@@ -166,9 +166,9 @@ class DiTBlock(nn.Module):
|
|
|
166
166
|
self.dim = dim
|
|
167
167
|
self.num_heads = num_heads
|
|
168
168
|
self.ffn_dim = ffn_dim
|
|
169
|
-
self.self_attn = SelfAttention(dim, num_heads, eps,
|
|
169
|
+
self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
170
170
|
self.cross_attn = CrossAttention(
|
|
171
|
-
dim, num_heads, eps, has_image_input=has_image_input,
|
|
171
|
+
dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
|
|
172
172
|
)
|
|
173
173
|
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
|
|
174
174
|
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
|
|
@@ -265,7 +265,7 @@ class WanDiT(PreTrainedModel):
|
|
|
265
265
|
num_layers: int,
|
|
266
266
|
has_image_input: bool,
|
|
267
267
|
flf_pos_emb: bool = False,
|
|
268
|
-
|
|
268
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
269
269
|
device: str = "cpu",
|
|
270
270
|
dtype: torch.dtype = torch.bfloat16,
|
|
271
271
|
):
|
|
@@ -296,7 +296,7 @@ class WanDiT(PreTrainedModel):
|
|
|
296
296
|
)
|
|
297
297
|
self.blocks = nn.ModuleList(
|
|
298
298
|
[
|
|
299
|
-
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps,
|
|
299
|
+
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
|
|
300
300
|
for _ in range(num_layers)
|
|
301
301
|
]
|
|
302
302
|
)
|
|
@@ -376,7 +376,7 @@ class WanDiT(PreTrainedModel):
|
|
|
376
376
|
device,
|
|
377
377
|
dtype,
|
|
378
378
|
model_type="1.3b-t2v",
|
|
379
|
-
|
|
379
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
380
380
|
assign=True,
|
|
381
381
|
):
|
|
382
382
|
if model_type == "1.3b-t2v":
|
|
@@ -390,7 +390,7 @@ class WanDiT(PreTrainedModel):
|
|
|
390
390
|
else:
|
|
391
391
|
raise ValueError(f"Unsupported model type: {model_type}")
|
|
392
392
|
with no_init_weights():
|
|
393
|
-
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype,
|
|
393
|
+
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
|
|
394
394
|
model = model.requires_grad_(False)
|
|
395
395
|
model.load_state_dict(state_dict, assign=assign)
|
|
396
396
|
model.to(device=device, dtype=dtype)
|
|
@@ -1,20 +1,17 @@
|
|
|
1
1
|
from .base import BasePipeline, LoRAStateDictConverter
|
|
2
2
|
from .controlnet_helper import ControlNetParams
|
|
3
|
-
from .flux_image import FluxImagePipeline
|
|
4
|
-
from .sdxl_image import SDXLImagePipeline
|
|
5
|
-
from .sd_image import SDImagePipeline
|
|
6
|
-
from .wan_video import WanVideoPipeline
|
|
3
|
+
from .flux_image import FluxImagePipeline
|
|
4
|
+
from .sdxl_image import SDXLImagePipeline
|
|
5
|
+
from .sd_image import SDImagePipeline
|
|
6
|
+
from .wan_video import WanVideoPipeline
|
|
7
|
+
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"BasePipeline",
|
|
10
11
|
"LoRAStateDictConverter",
|
|
11
12
|
"FluxImagePipeline",
|
|
12
|
-
"FluxModelConfig",
|
|
13
13
|
"SDXLImagePipeline",
|
|
14
|
-
"SDXLModelConfig",
|
|
15
14
|
"SDImagePipeline",
|
|
16
|
-
"SDModelConfig",
|
|
17
15
|
"WanVideoPipeline",
|
|
18
|
-
"WanModelConfig",
|
|
19
16
|
"ControlNetParams",
|
|
20
17
|
]
|
|
@@ -3,7 +3,7 @@ import torch
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from typing import Dict, List, Tuple
|
|
5
5
|
from PIL import Image
|
|
6
|
-
from
|
|
6
|
+
from diffsynth_engine.configs import BaseConfig
|
|
7
7
|
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
|
|
8
8
|
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
|
|
9
9
|
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
|
|
@@ -14,11 +14,6 @@ from diffsynth_engine.utils.platform import empty_cache
|
|
|
14
14
|
logger = logging.get_logger(__name__)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
@dataclass
|
|
18
|
-
class ModelConfig:
|
|
19
|
-
pass
|
|
20
|
-
|
|
21
|
-
|
|
22
17
|
class LoRAStateDictConverter:
|
|
23
18
|
def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
|
24
19
|
return {"lora": lora_state_dict}
|
|
@@ -30,8 +25,8 @@ class BasePipeline:
|
|
|
30
25
|
def __init__(
|
|
31
26
|
self,
|
|
32
27
|
vae_tiled: bool = False,
|
|
33
|
-
vae_tile_size: int = -1,
|
|
34
|
-
vae_tile_stride: int = -1,
|
|
28
|
+
vae_tile_size: int | Tuple[int, int] = -1,
|
|
29
|
+
vae_tile_stride: int | Tuple[int, int] = -1,
|
|
35
30
|
device="cuda",
|
|
36
31
|
dtype=torch.float16,
|
|
37
32
|
):
|
|
@@ -46,13 +41,7 @@ class BasePipeline:
|
|
|
46
41
|
self._models_offload_params = {}
|
|
47
42
|
|
|
48
43
|
@classmethod
|
|
49
|
-
def from_pretrained(
|
|
50
|
-
cls,
|
|
51
|
-
model_path_or_config: str | os.PathLike | ModelConfig,
|
|
52
|
-
device: str = "cuda",
|
|
53
|
-
dtype: torch.dtype = torch.float16,
|
|
54
|
-
offload_mode: str | None = None,
|
|
55
|
-
) -> "BasePipeline":
|
|
44
|
+
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
|
|
56
45
|
raise NotImplementedError()
|
|
57
46
|
|
|
58
47
|
@classmethod
|
|
@@ -224,54 +213,6 @@ class BasePipeline:
|
|
|
224
213
|
model.eval()
|
|
225
214
|
return self
|
|
226
215
|
|
|
227
|
-
@staticmethod
|
|
228
|
-
def init_parallel_config(
|
|
229
|
-
parallelism: int,
|
|
230
|
-
use_cfg_parallel: bool,
|
|
231
|
-
model_config: ModelConfig,
|
|
232
|
-
):
|
|
233
|
-
assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
|
|
234
|
-
cfg_degree = 2 if use_cfg_parallel else 1
|
|
235
|
-
sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
|
|
236
|
-
sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
|
|
237
|
-
tp_degree = getattr(model_config, "tp_degree", None)
|
|
238
|
-
use_fsdp = getattr(model_config, "use_fsdp", False)
|
|
239
|
-
|
|
240
|
-
if tp_degree is not None:
|
|
241
|
-
assert sp_ulysses_degree is None and sp_ring_degree is None, (
|
|
242
|
-
"not allowed to enable sequence parallel and tensor parallel together; "
|
|
243
|
-
"either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
|
|
244
|
-
)
|
|
245
|
-
assert use_fsdp is False, (
|
|
246
|
-
"not allowed to enable fully sharded data parallel and tensor parallel together; "
|
|
247
|
-
"either set use_fsdp=False or set tp_degree=None during pipeline initialization"
|
|
248
|
-
)
|
|
249
|
-
assert parallelism == cfg_degree * tp_degree, (
|
|
250
|
-
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
|
|
251
|
-
)
|
|
252
|
-
sp_ulysses_degree = 1
|
|
253
|
-
sp_ring_degree = 1
|
|
254
|
-
elif sp_ulysses_degree is None and sp_ring_degree is None:
|
|
255
|
-
# use ulysses if not specified
|
|
256
|
-
sp_ulysses_degree = parallelism // cfg_degree
|
|
257
|
-
sp_ring_degree = 1
|
|
258
|
-
tp_degree = 1
|
|
259
|
-
elif sp_ulysses_degree is not None and sp_ring_degree is not None:
|
|
260
|
-
assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
|
|
261
|
-
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
|
|
262
|
-
f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
|
|
263
|
-
)
|
|
264
|
-
tp_degree = 1
|
|
265
|
-
else:
|
|
266
|
-
raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
|
|
267
|
-
return {
|
|
268
|
-
"cfg_degree": cfg_degree,
|
|
269
|
-
"sp_ulysses_degree": sp_ulysses_degree,
|
|
270
|
-
"sp_ring_degree": sp_ring_degree,
|
|
271
|
-
"tp_degree": tp_degree,
|
|
272
|
-
"use_fsdp": use_fsdp,
|
|
273
|
-
}
|
|
274
|
-
|
|
275
216
|
def enable_cpu_offload(self, offload_mode: str):
|
|
276
217
|
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
|
|
277
218
|
if offload_mode not in valid_offload_mode:
|
|
@@ -326,14 +267,22 @@ class BasePipeline:
|
|
|
326
267
|
for model_name in self.model_names:
|
|
327
268
|
if model_name not in load_model_names:
|
|
328
269
|
model = getattr(self, model_name)
|
|
329
|
-
if
|
|
270
|
+
if (
|
|
271
|
+
model is not None
|
|
272
|
+
and (p := next(model.parameters(), None)) is not None
|
|
273
|
+
and p.device != torch.device("cpu")
|
|
274
|
+
):
|
|
330
275
|
param_cache = self._models_offload_params[model_name]
|
|
331
276
|
for name, param in model.named_parameters(recurse=True):
|
|
332
277
|
param.data = param_cache[name]
|
|
333
278
|
# load the needed models to device
|
|
334
279
|
for model_name in load_model_names:
|
|
335
280
|
model = getattr(self, model_name)
|
|
336
|
-
if
|
|
281
|
+
if (
|
|
282
|
+
model is not None
|
|
283
|
+
and (p := next(model.parameters(), None)) is not None
|
|
284
|
+
and p.device != torch.device(self.device)
|
|
285
|
+
):
|
|
337
286
|
model.to(self.device)
|
|
338
287
|
# fresh the cuda cache
|
|
339
288
|
empty_cache()
|