diffsynth-engine 0.3.6.dev8__py3-none-any.whl → 0.3.6.dev10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +10 -8
- diffsynth_engine/configs/__init__.py +23 -0
- diffsynth_engine/configs/controlnet.py +17 -0
- diffsynth_engine/configs/pipeline.py +206 -0
- diffsynth_engine/models/basic/attention.py +43 -4
- diffsynth_engine/models/flux/flux_controlnet.py +8 -5
- diffsynth_engine/models/flux/flux_dit.py +22 -16
- diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/sd/sd_controlnet.py +2 -4
- diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
- diffsynth_engine/models/wan/wan_dit.py +15 -15
- diffsynth_engine/pipelines/__init__.py +5 -8
- diffsynth_engine/pipelines/base.py +14 -65
- diffsynth_engine/pipelines/flux_image.py +85 -158
- diffsynth_engine/pipelines/sd_image.py +30 -64
- diffsynth_engine/pipelines/sdxl_image.py +39 -71
- diffsynth_engine/pipelines/wan_video.py +66 -105
- diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_reference_tool.py +21 -5
- diffsynth_engine/tools/flux_replace_tool.py +15 -3
- diffsynth_engine/utils/fp8_linear.py +14 -5
- diffsynth_engine/utils/parallel.py +1 -1
- diffsynth_engine/utils/platform.py +9 -1
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/RECORD +30 -27
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@ import math
|
|
|
2
2
|
import json
|
|
3
3
|
import torch
|
|
4
4
|
import torch.nn as nn
|
|
5
|
-
from typing import Tuple, Optional
|
|
5
|
+
from typing import Any, Dict, Tuple, Optional
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
|
|
8
8
|
from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
|
|
@@ -69,7 +69,7 @@ class SelfAttention(nn.Module):
|
|
|
69
69
|
dim: int,
|
|
70
70
|
num_heads: int,
|
|
71
71
|
eps: float = 1e-6,
|
|
72
|
-
|
|
72
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
73
73
|
device: str = "cuda:0",
|
|
74
74
|
dtype: torch.dtype = torch.bfloat16,
|
|
75
75
|
):
|
|
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
|
|
|
82
82
|
self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
83
83
|
self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
84
84
|
self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
85
|
-
self.
|
|
85
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
86
86
|
|
|
87
87
|
def forward(self, x, freqs):
|
|
88
88
|
q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
|
|
@@ -94,7 +94,7 @@ class SelfAttention(nn.Module):
|
|
|
94
94
|
q=rope_apply(q, freqs),
|
|
95
95
|
k=rope_apply(k, freqs),
|
|
96
96
|
v=v,
|
|
97
|
-
|
|
97
|
+
**self.attn_kwargs,
|
|
98
98
|
)
|
|
99
99
|
x = x.flatten(2)
|
|
100
100
|
return self.o(x)
|
|
@@ -107,7 +107,7 @@ class CrossAttention(nn.Module):
|
|
|
107
107
|
num_heads: int,
|
|
108
108
|
eps: float = 1e-6,
|
|
109
109
|
has_image_input: bool = False,
|
|
110
|
-
|
|
110
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
111
111
|
device: str = "cuda:0",
|
|
112
112
|
dtype: torch.dtype = torch.bfloat16,
|
|
113
113
|
):
|
|
@@ -126,7 +126,7 @@ class CrossAttention(nn.Module):
|
|
|
126
126
|
self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
127
127
|
self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
|
|
128
128
|
self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
|
|
129
|
-
self.
|
|
129
|
+
self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
|
|
130
130
|
|
|
131
131
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
|
132
132
|
if self.has_image_input:
|
|
@@ -140,12 +140,12 @@ class CrossAttention(nn.Module):
|
|
|
140
140
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
141
141
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
142
142
|
|
|
143
|
-
x = attention_ops.attention(q, k, v,
|
|
143
|
+
x = attention_ops.attention(q, k, v, **self.attn_kwargs).flatten(2)
|
|
144
144
|
if self.has_image_input:
|
|
145
145
|
k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
|
|
146
146
|
k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
|
|
147
147
|
v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
|
|
148
|
-
y = attention_ops.attention(q, k_img, v_img,
|
|
148
|
+
y = attention_ops.attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
|
|
149
149
|
x = x + y
|
|
150
150
|
return self.o(x)
|
|
151
151
|
|
|
@@ -158,7 +158,7 @@ class DiTBlock(nn.Module):
|
|
|
158
158
|
num_heads: int,
|
|
159
159
|
ffn_dim: int,
|
|
160
160
|
eps: float = 1e-6,
|
|
161
|
-
|
|
161
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
162
162
|
device: str = "cuda:0",
|
|
163
163
|
dtype: torch.dtype = torch.bfloat16,
|
|
164
164
|
):
|
|
@@ -166,9 +166,9 @@ class DiTBlock(nn.Module):
|
|
|
166
166
|
self.dim = dim
|
|
167
167
|
self.num_heads = num_heads
|
|
168
168
|
self.ffn_dim = ffn_dim
|
|
169
|
-
self.self_attn = SelfAttention(dim, num_heads, eps,
|
|
169
|
+
self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
|
|
170
170
|
self.cross_attn = CrossAttention(
|
|
171
|
-
dim, num_heads, eps, has_image_input=has_image_input,
|
|
171
|
+
dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
|
|
172
172
|
)
|
|
173
173
|
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
|
|
174
174
|
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
|
|
@@ -265,7 +265,7 @@ class WanDiT(PreTrainedModel):
|
|
|
265
265
|
num_layers: int,
|
|
266
266
|
has_image_input: bool,
|
|
267
267
|
flf_pos_emb: bool = False,
|
|
268
|
-
|
|
268
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
269
269
|
device: str = "cpu",
|
|
270
270
|
dtype: torch.dtype = torch.bfloat16,
|
|
271
271
|
):
|
|
@@ -296,7 +296,7 @@ class WanDiT(PreTrainedModel):
|
|
|
296
296
|
)
|
|
297
297
|
self.blocks = nn.ModuleList(
|
|
298
298
|
[
|
|
299
|
-
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps,
|
|
299
|
+
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
|
|
300
300
|
for _ in range(num_layers)
|
|
301
301
|
]
|
|
302
302
|
)
|
|
@@ -376,7 +376,7 @@ class WanDiT(PreTrainedModel):
|
|
|
376
376
|
device,
|
|
377
377
|
dtype,
|
|
378
378
|
model_type="1.3b-t2v",
|
|
379
|
-
|
|
379
|
+
attn_kwargs: Optional[Dict[str, Any]] = None,
|
|
380
380
|
assign=True,
|
|
381
381
|
):
|
|
382
382
|
if model_type == "1.3b-t2v":
|
|
@@ -390,7 +390,7 @@ class WanDiT(PreTrainedModel):
|
|
|
390
390
|
else:
|
|
391
391
|
raise ValueError(f"Unsupported model type: {model_type}")
|
|
392
392
|
with no_init_weights():
|
|
393
|
-
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype,
|
|
393
|
+
model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
|
|
394
394
|
model = model.requires_grad_(False)
|
|
395
395
|
model.load_state_dict(state_dict, assign=assign)
|
|
396
396
|
model.to(device=device, dtype=dtype)
|
|
@@ -1,20 +1,17 @@
|
|
|
1
1
|
from .base import BasePipeline, LoRAStateDictConverter
|
|
2
2
|
from .controlnet_helper import ControlNetParams
|
|
3
|
-
from .flux_image import FluxImagePipeline
|
|
4
|
-
from .sdxl_image import SDXLImagePipeline
|
|
5
|
-
from .sd_image import SDImagePipeline
|
|
6
|
-
from .wan_video import WanVideoPipeline
|
|
3
|
+
from .flux_image import FluxImagePipeline
|
|
4
|
+
from .sdxl_image import SDXLImagePipeline
|
|
5
|
+
from .sd_image import SDImagePipeline
|
|
6
|
+
from .wan_video import WanVideoPipeline
|
|
7
|
+
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"BasePipeline",
|
|
10
11
|
"LoRAStateDictConverter",
|
|
11
12
|
"FluxImagePipeline",
|
|
12
|
-
"FluxModelConfig",
|
|
13
13
|
"SDXLImagePipeline",
|
|
14
|
-
"SDXLModelConfig",
|
|
15
14
|
"SDImagePipeline",
|
|
16
|
-
"SDModelConfig",
|
|
17
15
|
"WanVideoPipeline",
|
|
18
|
-
"WanModelConfig",
|
|
19
16
|
"ControlNetParams",
|
|
20
17
|
]
|
|
@@ -3,7 +3,7 @@ import torch
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from typing import Dict, List, Tuple
|
|
5
5
|
from PIL import Image
|
|
6
|
-
from
|
|
6
|
+
from diffsynth_engine.configs import BaseConfig
|
|
7
7
|
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
|
|
8
8
|
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
|
|
9
9
|
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
|
|
@@ -14,11 +14,6 @@ from diffsynth_engine.utils.platform import empty_cache
|
|
|
14
14
|
logger = logging.get_logger(__name__)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
@dataclass
|
|
18
|
-
class ModelConfig:
|
|
19
|
-
pass
|
|
20
|
-
|
|
21
|
-
|
|
22
17
|
class LoRAStateDictConverter:
|
|
23
18
|
def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
|
24
19
|
return {"lora": lora_state_dict}
|
|
@@ -30,8 +25,8 @@ class BasePipeline:
|
|
|
30
25
|
def __init__(
|
|
31
26
|
self,
|
|
32
27
|
vae_tiled: bool = False,
|
|
33
|
-
vae_tile_size: int = -1,
|
|
34
|
-
vae_tile_stride: int = -1,
|
|
28
|
+
vae_tile_size: int | Tuple[int, int] = -1,
|
|
29
|
+
vae_tile_stride: int | Tuple[int, int] = -1,
|
|
35
30
|
device="cuda",
|
|
36
31
|
dtype=torch.float16,
|
|
37
32
|
):
|
|
@@ -46,13 +41,7 @@ class BasePipeline:
|
|
|
46
41
|
self._models_offload_params = {}
|
|
47
42
|
|
|
48
43
|
@classmethod
|
|
49
|
-
def from_pretrained(
|
|
50
|
-
cls,
|
|
51
|
-
model_path_or_config: str | os.PathLike | ModelConfig,
|
|
52
|
-
device: str = "cuda",
|
|
53
|
-
dtype: torch.dtype = torch.float16,
|
|
54
|
-
offload_mode: str | None = None,
|
|
55
|
-
) -> "BasePipeline":
|
|
44
|
+
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
|
|
56
45
|
raise NotImplementedError()
|
|
57
46
|
|
|
58
47
|
@classmethod
|
|
@@ -224,54 +213,6 @@ class BasePipeline:
|
|
|
224
213
|
model.eval()
|
|
225
214
|
return self
|
|
226
215
|
|
|
227
|
-
@staticmethod
|
|
228
|
-
def init_parallel_config(
|
|
229
|
-
parallelism: int,
|
|
230
|
-
use_cfg_parallel: bool,
|
|
231
|
-
model_config: ModelConfig,
|
|
232
|
-
):
|
|
233
|
-
assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
|
|
234
|
-
cfg_degree = 2 if use_cfg_parallel else 1
|
|
235
|
-
sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
|
|
236
|
-
sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
|
|
237
|
-
tp_degree = getattr(model_config, "tp_degree", None)
|
|
238
|
-
use_fsdp = getattr(model_config, "use_fsdp", False)
|
|
239
|
-
|
|
240
|
-
if tp_degree is not None:
|
|
241
|
-
assert sp_ulysses_degree is None and sp_ring_degree is None, (
|
|
242
|
-
"not allowed to enable sequence parallel and tensor parallel together; "
|
|
243
|
-
"either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
|
|
244
|
-
)
|
|
245
|
-
assert use_fsdp is False, (
|
|
246
|
-
"not allowed to enable fully sharded data parallel and tensor parallel together; "
|
|
247
|
-
"either set use_fsdp=False or set tp_degree=None during pipeline initialization"
|
|
248
|
-
)
|
|
249
|
-
assert parallelism == cfg_degree * tp_degree, (
|
|
250
|
-
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
|
|
251
|
-
)
|
|
252
|
-
sp_ulysses_degree = 1
|
|
253
|
-
sp_ring_degree = 1
|
|
254
|
-
elif sp_ulysses_degree is None and sp_ring_degree is None:
|
|
255
|
-
# use ulysses if not specified
|
|
256
|
-
sp_ulysses_degree = parallelism // cfg_degree
|
|
257
|
-
sp_ring_degree = 1
|
|
258
|
-
tp_degree = 1
|
|
259
|
-
elif sp_ulysses_degree is not None and sp_ring_degree is not None:
|
|
260
|
-
assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
|
|
261
|
-
f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
|
|
262
|
-
f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
|
|
263
|
-
)
|
|
264
|
-
tp_degree = 1
|
|
265
|
-
else:
|
|
266
|
-
raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
|
|
267
|
-
return {
|
|
268
|
-
"cfg_degree": cfg_degree,
|
|
269
|
-
"sp_ulysses_degree": sp_ulysses_degree,
|
|
270
|
-
"sp_ring_degree": sp_ring_degree,
|
|
271
|
-
"tp_degree": tp_degree,
|
|
272
|
-
"use_fsdp": use_fsdp,
|
|
273
|
-
}
|
|
274
|
-
|
|
275
216
|
def enable_cpu_offload(self, offload_mode: str):
|
|
276
217
|
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
|
|
277
218
|
if offload_mode not in valid_offload_mode:
|
|
@@ -326,14 +267,22 @@ class BasePipeline:
|
|
|
326
267
|
for model_name in self.model_names:
|
|
327
268
|
if model_name not in load_model_names:
|
|
328
269
|
model = getattr(self, model_name)
|
|
329
|
-
if
|
|
270
|
+
if (
|
|
271
|
+
model is not None
|
|
272
|
+
and (p := next(model.parameters(), None)) is not None
|
|
273
|
+
and p.device != torch.device("cpu")
|
|
274
|
+
):
|
|
330
275
|
param_cache = self._models_offload_params[model_name]
|
|
331
276
|
for name, param in model.named_parameters(recurse=True):
|
|
332
277
|
param.data = param_cache[name]
|
|
333
278
|
# load the needed models to device
|
|
334
279
|
for model_name in load_model_names:
|
|
335
280
|
model = getattr(self, model_name)
|
|
336
|
-
if
|
|
281
|
+
if (
|
|
282
|
+
model is not None
|
|
283
|
+
and (p := next(model.parameters(), None)) is not None
|
|
284
|
+
and p.device != torch.device(self.device)
|
|
285
|
+
):
|
|
337
286
|
model.to(self.device)
|
|
338
287
|
# fresh the cuda cache
|
|
339
288
|
empty_cache()
|
|
@@ -1,15 +1,12 @@
|
|
|
1
1
|
import re
|
|
2
|
-
import os
|
|
3
2
|
import json
|
|
4
3
|
import torch
|
|
5
4
|
import torch.distributed as dist
|
|
6
5
|
import math
|
|
7
6
|
from einops import rearrange
|
|
8
|
-
from enum import Enum
|
|
9
7
|
from typing import Callable, Dict, List, Tuple, Optional, Union
|
|
10
8
|
from tqdm import tqdm
|
|
11
9
|
from PIL import Image
|
|
12
|
-
from dataclasses import dataclass
|
|
13
10
|
from diffsynth_engine.models.flux import (
|
|
14
11
|
FluxTextEncoder1,
|
|
15
12
|
FluxTextEncoder2,
|
|
@@ -20,6 +17,7 @@ from diffsynth_engine.models.flux import (
|
|
|
20
17
|
flux_dit_config,
|
|
21
18
|
flux_text_encoder_config,
|
|
22
19
|
)
|
|
20
|
+
from diffsynth_engine.configs import FluxPipelineConfig, ControlType
|
|
23
21
|
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
24
22
|
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
|
|
25
23
|
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
|
|
@@ -416,48 +414,12 @@ def calculate_shift(
|
|
|
416
414
|
return mu
|
|
417
415
|
|
|
418
416
|
|
|
419
|
-
class ControlType(Enum):
|
|
420
|
-
normal = "normal"
|
|
421
|
-
bfl_control = "bfl_control"
|
|
422
|
-
bfl_fill = "bfl_fill"
|
|
423
|
-
bfl_kontext = "bfl_kontext"
|
|
424
|
-
|
|
425
|
-
def get_in_channel(self):
|
|
426
|
-
if self in [ControlType.normal, ControlType.bfl_kontext]:
|
|
427
|
-
return 64
|
|
428
|
-
elif self == ControlType.bfl_control:
|
|
429
|
-
return 128
|
|
430
|
-
elif self == ControlType.bfl_fill:
|
|
431
|
-
return 384
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
@dataclass
|
|
435
|
-
class FluxModelConfig:
|
|
436
|
-
dit_path: str | os.PathLike
|
|
437
|
-
clip_path: Optional[str | os.PathLike] = None
|
|
438
|
-
t5_path: Optional[str | os.PathLike] = None
|
|
439
|
-
vae_path: Optional[str | os.PathLike] = None
|
|
440
|
-
|
|
441
|
-
dit_dtype: torch.dtype = torch.bfloat16
|
|
442
|
-
clip_dtype: torch.dtype = torch.bfloat16
|
|
443
|
-
t5_dtype: torch.dtype = torch.bfloat16
|
|
444
|
-
vae_dtype: torch.dtype = torch.bfloat16
|
|
445
|
-
|
|
446
|
-
dit_attn_impl: Optional[str] = "auto"
|
|
447
|
-
use_fp8_linear: bool = False
|
|
448
|
-
|
|
449
|
-
sp_ulysses_degree: Optional[int] = None
|
|
450
|
-
sp_ring_degree: Optional[int] = None
|
|
451
|
-
tp_degree: Optional[int] = None
|
|
452
|
-
use_fsdp: bool = False
|
|
453
|
-
|
|
454
|
-
|
|
455
417
|
class FluxImagePipeline(BasePipeline):
|
|
456
418
|
lora_converter = FluxLoRAConverter()
|
|
457
419
|
|
|
458
420
|
def __init__(
|
|
459
421
|
self,
|
|
460
|
-
config:
|
|
422
|
+
config: FluxPipelineConfig,
|
|
461
423
|
tokenizer: CLIPTokenizer,
|
|
462
424
|
tokenizer_2: T5TokenizerFast,
|
|
463
425
|
text_encoder_1: FluxTextEncoder1,
|
|
@@ -465,23 +427,16 @@ class FluxImagePipeline(BasePipeline):
|
|
|
465
427
|
dit: Union[FluxDiT, FluxDiTFBCache],
|
|
466
428
|
vae_decoder: FluxVAEDecoder,
|
|
467
429
|
vae_encoder: FluxVAEEncoder,
|
|
468
|
-
load_text_encoder: bool = True,
|
|
469
|
-
batch_cfg: bool = False,
|
|
470
|
-
vae_tiled: bool = False,
|
|
471
|
-
vae_tile_size: int = 256,
|
|
472
|
-
vae_tile_stride: int = 256,
|
|
473
|
-
control_type: ControlType = ControlType.normal,
|
|
474
|
-
device: str = "cuda",
|
|
475
|
-
dtype: torch.dtype = torch.bfloat16,
|
|
476
430
|
):
|
|
477
431
|
super().__init__(
|
|
478
|
-
vae_tiled=vae_tiled,
|
|
479
|
-
vae_tile_size=vae_tile_size,
|
|
480
|
-
vae_tile_stride=vae_tile_stride,
|
|
481
|
-
device=device,
|
|
482
|
-
dtype=
|
|
432
|
+
vae_tiled=config.vae_tiled,
|
|
433
|
+
vae_tile_size=config.vae_tile_size,
|
|
434
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
435
|
+
device=config.device,
|
|
436
|
+
dtype=config.model_dtype,
|
|
483
437
|
)
|
|
484
438
|
self.config = config
|
|
439
|
+
# sampler
|
|
485
440
|
self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
|
|
486
441
|
self.sampler = FlowMatchEulerSampler()
|
|
487
442
|
# models
|
|
@@ -492,11 +447,8 @@ class FluxImagePipeline(BasePipeline):
|
|
|
492
447
|
self.dit = dit
|
|
493
448
|
self.vae_decoder = vae_decoder
|
|
494
449
|
self.vae_encoder = vae_encoder
|
|
495
|
-
self.load_text_encoder = load_text_encoder
|
|
496
|
-
self.batch_cfg = batch_cfg
|
|
497
450
|
self.ip_adapter = None
|
|
498
451
|
self.redux = None
|
|
499
|
-
self.control_type = control_type
|
|
500
452
|
self.model_names = [
|
|
501
453
|
"text_encoder_1",
|
|
502
454
|
"text_encoder_2",
|
|
@@ -506,140 +458,115 @@ class FluxImagePipeline(BasePipeline):
|
|
|
506
458
|
]
|
|
507
459
|
|
|
508
460
|
@classmethod
|
|
509
|
-
def from_pretrained(
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
parallelism: int = 1,
|
|
522
|
-
use_cfg_parallel: bool = False,
|
|
523
|
-
use_fb_cache: bool = False,
|
|
524
|
-
fb_cache_relative_l1_threshold: float = 0.05,
|
|
525
|
-
) -> "FluxImagePipeline":
|
|
526
|
-
model_config = (
|
|
527
|
-
model_path_or_config
|
|
528
|
-
if isinstance(model_path_or_config, FluxModelConfig)
|
|
529
|
-
else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
|
|
530
|
-
)
|
|
531
|
-
if model_config.vae_path is None:
|
|
532
|
-
model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
|
|
533
|
-
|
|
534
|
-
if model_config.clip_path is None and load_text_encoder:
|
|
535
|
-
model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
|
|
536
|
-
if model_config.t5_path is None and load_text_encoder:
|
|
537
|
-
model_config.t5_path = fetch_model(
|
|
461
|
+
def from_pretrained(cls, model_path_or_config: str | FluxPipelineConfig) -> "FluxImagePipeline":
|
|
462
|
+
if isinstance(model_path_or_config, str):
|
|
463
|
+
config = FluxPipelineConfig(model_path=model_path_or_config)
|
|
464
|
+
else:
|
|
465
|
+
config = model_path_or_config
|
|
466
|
+
|
|
467
|
+
if config.vae_path is None:
|
|
468
|
+
config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
|
|
469
|
+
if config.clip_path is None and config.load_text_encoder:
|
|
470
|
+
config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
|
|
471
|
+
if config.t5_path is None and config.load_text_encoder:
|
|
472
|
+
config.t5_path = fetch_model(
|
|
538
473
|
"muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
|
|
539
474
|
)
|
|
540
475
|
|
|
541
|
-
logger.info(f"loading state dict from {
|
|
542
|
-
dit_state_dict = cls.load_model_checkpoint(
|
|
543
|
-
logger.info(f"loading state dict from {
|
|
544
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
545
|
-
if load_text_encoder:
|
|
546
|
-
logger.info(f"loading state dict from {
|
|
547
|
-
clip_state_dict = cls.load_model_checkpoint(
|
|
548
|
-
|
|
549
|
-
)
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
|
|
554
|
-
if load_text_encoder:
|
|
476
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
477
|
+
dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
478
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
479
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
480
|
+
if config.load_text_encoder:
|
|
481
|
+
logger.info(f"loading state dict from {config.clip_path} ...")
|
|
482
|
+
clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
|
|
483
|
+
logger.info(f"loading state dict from {config.t5_path} ...")
|
|
484
|
+
t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
|
|
485
|
+
|
|
486
|
+
init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
|
|
487
|
+
if config.load_text_encoder:
|
|
555
488
|
tokenizer = CLIPTokenizer.from_pretrained(FLUX_TOKENIZER_1_CONF_PATH)
|
|
556
489
|
tokenizer_2 = T5TokenizerFast.from_pretrained(FLUX_TOKENIZER_2_CONF_PATH)
|
|
557
490
|
with LoRAContext():
|
|
558
491
|
text_encoder_1 = FluxTextEncoder1.from_state_dict(
|
|
559
|
-
clip_state_dict, device=init_device, dtype=
|
|
492
|
+
clip_state_dict, device=init_device, dtype=config.clip_dtype
|
|
560
493
|
)
|
|
561
|
-
text_encoder_2 = FluxTextEncoder2.from_state_dict(
|
|
562
|
-
t5_state_dict, device=init_device, dtype=model_config.t5_dtype
|
|
563
|
-
)
|
|
494
|
+
text_encoder_2 = FluxTextEncoder2.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
|
|
564
495
|
|
|
565
|
-
vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=
|
|
566
|
-
vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=
|
|
496
|
+
vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
|
|
497
|
+
vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
|
|
567
498
|
|
|
568
499
|
with LoRAContext():
|
|
569
|
-
|
|
500
|
+
attn_kwargs = {
|
|
501
|
+
"attn_impl": config.dit_attn_impl,
|
|
502
|
+
"sparge_smooth_k": config.sparge_smooth_k,
|
|
503
|
+
"sparge_cdfthreshd": config.sparge_cdfthreshd,
|
|
504
|
+
"sparge_simthreshd1": config.sparge_simthreshd1,
|
|
505
|
+
"sparge_pvthreshd": config.sparge_pvthreshd,
|
|
506
|
+
}
|
|
507
|
+
if config.use_fbcache:
|
|
570
508
|
dit = FluxDiTFBCache.from_state_dict(
|
|
571
509
|
dit_state_dict,
|
|
572
510
|
device=init_device,
|
|
573
|
-
dtype=
|
|
574
|
-
in_channel=control_type.get_in_channel(),
|
|
575
|
-
|
|
576
|
-
relative_l1_threshold=
|
|
511
|
+
dtype=config.model_dtype,
|
|
512
|
+
in_channel=config.control_type.get_in_channel(),
|
|
513
|
+
attn_kwargs=attn_kwargs,
|
|
514
|
+
relative_l1_threshold=config.fbcache_relative_l1_threshold,
|
|
577
515
|
)
|
|
578
516
|
else:
|
|
579
517
|
dit = FluxDiT.from_state_dict(
|
|
580
518
|
dit_state_dict,
|
|
581
519
|
device=init_device,
|
|
582
|
-
dtype=
|
|
583
|
-
in_channel=control_type.get_in_channel(),
|
|
584
|
-
|
|
520
|
+
dtype=config.model_dtype,
|
|
521
|
+
in_channel=config.control_type.get_in_channel(),
|
|
522
|
+
attn_kwargs=attn_kwargs,
|
|
585
523
|
)
|
|
586
|
-
if
|
|
524
|
+
if config.use_fp8_linear:
|
|
587
525
|
enable_fp8_linear(dit)
|
|
588
526
|
|
|
589
527
|
pipe = cls(
|
|
590
|
-
config=
|
|
591
|
-
tokenizer=tokenizer if load_text_encoder else None,
|
|
592
|
-
tokenizer_2=tokenizer_2 if load_text_encoder else None,
|
|
593
|
-
text_encoder_1=text_encoder_1 if load_text_encoder else None,
|
|
594
|
-
text_encoder_2=text_encoder_2 if load_text_encoder else None,
|
|
528
|
+
config=config,
|
|
529
|
+
tokenizer=tokenizer if config.load_text_encoder else None,
|
|
530
|
+
tokenizer_2=tokenizer_2 if config.load_text_encoder else None,
|
|
531
|
+
text_encoder_1=text_encoder_1 if config.load_text_encoder else None,
|
|
532
|
+
text_encoder_2=text_encoder_2 if config.load_text_encoder else None,
|
|
595
533
|
dit=dit,
|
|
596
534
|
vae_decoder=vae_decoder,
|
|
597
535
|
vae_encoder=vae_encoder,
|
|
598
|
-
load_text_encoder=load_text_encoder,
|
|
599
|
-
batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
|
|
600
|
-
vae_tiled=vae_tiled,
|
|
601
|
-
vae_tile_size=vae_tile_size,
|
|
602
|
-
vae_tile_stride=vae_tile_stride,
|
|
603
|
-
control_type=control_type,
|
|
604
|
-
device=device,
|
|
605
|
-
dtype=model_config.dit_dtype,
|
|
606
536
|
)
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
if
|
|
610
|
-
pipe.
|
|
537
|
+
pipe.eval()
|
|
538
|
+
|
|
539
|
+
if config.offload_mode is not None:
|
|
540
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
541
|
+
|
|
542
|
+
if config.model_dtype == torch.float8_e4m3fn:
|
|
543
|
+
pipe.dtype = torch.bfloat16 # compute dtype
|
|
611
544
|
pipe.enable_fp8_autocast(
|
|
612
|
-
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=
|
|
545
|
+
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
|
|
613
546
|
)
|
|
614
547
|
|
|
615
|
-
if
|
|
616
|
-
pipe.dtype = torch.bfloat16 #
|
|
548
|
+
if config.t5_dtype == torch.float8_e4m3fn:
|
|
549
|
+
pipe.dtype = torch.bfloat16 # compute dtype
|
|
617
550
|
pipe.enable_fp8_autocast(
|
|
618
|
-
model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=
|
|
551
|
+
model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
|
|
619
552
|
)
|
|
620
553
|
|
|
621
|
-
if parallelism > 1:
|
|
622
|
-
parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
|
|
623
|
-
cfg_degree = parallel_config["cfg_degree"]
|
|
624
|
-
sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
|
|
625
|
-
sp_ring_degree = parallel_config["sp_ring_degree"]
|
|
626
|
-
tp_degree = parallel_config["tp_degree"]
|
|
627
|
-
use_fsdp = parallel_config["use_fsdp"]
|
|
554
|
+
if config.parallelism > 1:
|
|
628
555
|
return ParallelWrapper(
|
|
629
556
|
pipe,
|
|
630
|
-
cfg_degree=cfg_degree,
|
|
631
|
-
sp_ulysses_degree=sp_ulysses_degree,
|
|
632
|
-
sp_ring_degree=sp_ring_degree,
|
|
633
|
-
tp_degree=tp_degree,
|
|
634
|
-
use_fsdp=use_fsdp,
|
|
557
|
+
cfg_degree=config.cfg_degree,
|
|
558
|
+
sp_ulysses_degree=config.sp_ulysses_degree,
|
|
559
|
+
sp_ring_degree=config.sp_ring_degree,
|
|
560
|
+
tp_degree=config.tp_degree,
|
|
561
|
+
use_fsdp=config.use_fsdp,
|
|
635
562
|
device="cuda",
|
|
636
563
|
)
|
|
637
564
|
return pipe
|
|
638
565
|
|
|
639
566
|
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
|
|
640
|
-
assert self.config.tp_degree is None, (
|
|
567
|
+
assert self.config.tp_degree is None or self.config.tp_degree == 1, (
|
|
641
568
|
"load LoRA is not allowed when tensor parallel is enabled; "
|
|
642
|
-
"set tp_degree=None during pipeline initialization"
|
|
569
|
+
"set tp_degree=None or tp_degree=1 during pipeline initialization"
|
|
643
570
|
)
|
|
644
571
|
assert not (self.config.use_fsdp and fused), (
|
|
645
572
|
"load fused LoRA is not allowed when fully sharded data parallel is enabled; "
|
|
@@ -791,9 +718,9 @@ class FluxImagePipeline(BasePipeline):
|
|
|
791
718
|
total_step: int,
|
|
792
719
|
):
|
|
793
720
|
origin_latents_shape = latents.shape
|
|
794
|
-
if self.control_type != ControlType.normal:
|
|
721
|
+
if self.config.control_type != ControlType.normal:
|
|
795
722
|
controlnet_param = controlnet_params[0]
|
|
796
|
-
if self.control_type == ControlType.bfl_kontext:
|
|
723
|
+
if self.config.control_type == ControlType.bfl_kontext:
|
|
797
724
|
latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
|
|
798
725
|
image_ids = image_ids.repeat(1, 2, 1)
|
|
799
726
|
image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
|
|
@@ -828,7 +755,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
828
755
|
controlnet_double_block_output=double_block_output,
|
|
829
756
|
controlnet_single_block_output=single_block_output,
|
|
830
757
|
)
|
|
831
|
-
if self.control_type == ControlType.bfl_kontext:
|
|
758
|
+
if self.config.control_type == ControlType.bfl_kontext:
|
|
832
759
|
noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
|
|
833
760
|
return noise_pred
|
|
834
761
|
|
|
@@ -877,7 +804,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
877
804
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
878
805
|
latent = self.encode_image(image)
|
|
879
806
|
else:
|
|
880
|
-
if self.control_type == ControlType.normal:
|
|
807
|
+
if self.config.control_type == ControlType.normal:
|
|
881
808
|
image = image.resize((width, height))
|
|
882
809
|
mask = mask.resize((width, height))
|
|
883
810
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
@@ -888,7 +815,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
888
815
|
mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
|
|
889
816
|
mask = 1 - mask
|
|
890
817
|
latent = torch.cat([latent, mask], dim=1)
|
|
891
|
-
elif self.control_type == ControlType.bfl_fill:
|
|
818
|
+
elif self.config.control_type == ControlType.bfl_fill:
|
|
892
819
|
image = image.resize((width, height))
|
|
893
820
|
mask = mask.resize((width, height))
|
|
894
821
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
@@ -898,7 +825,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
898
825
|
mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
|
899
826
|
latent = torch.cat((image, mask), dim=1)
|
|
900
827
|
else:
|
|
901
|
-
raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.control_type}")
|
|
828
|
+
raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.config.control_type}")
|
|
902
829
|
return latent
|
|
903
830
|
|
|
904
831
|
def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
|
|
@@ -996,7 +923,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
996
923
|
self.dit.refresh_cache_status(num_inference_steps)
|
|
997
924
|
if not isinstance(controlnet_params, list):
|
|
998
925
|
controlnet_params = [controlnet_params]
|
|
999
|
-
if self.control_type != ControlType.normal:
|
|
926
|
+
if self.config.control_type != ControlType.normal:
|
|
1000
927
|
assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
|
|
1001
928
|
|
|
1002
929
|
if input_image is not None:
|
|
@@ -1033,7 +960,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
1033
960
|
if self.ip_adapter is not None and self.redux is not None:
|
|
1034
961
|
raise Exception("ip-adapter and flux redux cannot be used at the same time")
|
|
1035
962
|
elif self.ip_adapter is not None:
|
|
1036
|
-
image_emb = self.ip_adapter(ref_image)
|
|
963
|
+
image_emb = self.ip_adapter.encode_image(ref_image)
|
|
1037
964
|
elif self.redux is not None:
|
|
1038
965
|
image_prompt_embeds = self.redux(ref_image)
|
|
1039
966
|
positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
|
|
@@ -1063,7 +990,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
1063
990
|
controlnet_params=controlnet_params,
|
|
1064
991
|
current_step=i,
|
|
1065
992
|
total_step=len(timesteps),
|
|
1066
|
-
batch_cfg=self.batch_cfg,
|
|
993
|
+
batch_cfg=self.config.batch_cfg,
|
|
1067
994
|
)
|
|
1068
995
|
# Denoise
|
|
1069
996
|
latents = self.sampler.step(latents, noise_pred, i)
|