diffsynth-engine 0.3.6.dev8__py3-none-any.whl → 0.3.6.dev10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. diffsynth_engine/__init__.py +10 -8
  2. diffsynth_engine/configs/__init__.py +23 -0
  3. diffsynth_engine/configs/controlnet.py +17 -0
  4. diffsynth_engine/configs/pipeline.py +206 -0
  5. diffsynth_engine/models/basic/attention.py +43 -4
  6. diffsynth_engine/models/flux/flux_controlnet.py +8 -5
  7. diffsynth_engine/models/flux/flux_dit.py +22 -16
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +5 -5
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/sd/sd_controlnet.py +2 -4
  11. diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
  12. diffsynth_engine/models/wan/wan_dit.py +15 -15
  13. diffsynth_engine/pipelines/__init__.py +5 -8
  14. diffsynth_engine/pipelines/base.py +14 -65
  15. diffsynth_engine/pipelines/flux_image.py +85 -158
  16. diffsynth_engine/pipelines/sd_image.py +30 -64
  17. diffsynth_engine/pipelines/sdxl_image.py +39 -71
  18. diffsynth_engine/pipelines/wan_video.py +66 -105
  19. diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
  20. diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
  21. diffsynth_engine/tools/flux_reference_tool.py +21 -5
  22. diffsynth_engine/tools/flux_replace_tool.py +15 -3
  23. diffsynth_engine/utils/fp8_linear.py +14 -5
  24. diffsynth_engine/utils/parallel.py +1 -1
  25. diffsynth_engine/utils/platform.py +9 -1
  26. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/METADATA +1 -1
  27. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/RECORD +30 -27
  28. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/WHEEL +0 -0
  29. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/licenses/LICENSE +0 -0
  30. {diffsynth_engine-0.3.6.dev8.dist-info → diffsynth_engine-0.3.6.dev10.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import math
2
2
  import json
3
3
  import torch
4
4
  import torch.nn as nn
5
- from typing import Tuple, Optional
5
+ from typing import Any, Dict, Tuple, Optional
6
6
  from einops import rearrange
7
7
 
8
8
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
@@ -69,7 +69,7 @@ class SelfAttention(nn.Module):
69
69
  dim: int,
70
70
  num_heads: int,
71
71
  eps: float = 1e-6,
72
- attn_impl: Optional[str] = None,
72
+ attn_kwargs: Optional[Dict[str, Any]] = None,
73
73
  device: str = "cuda:0",
74
74
  dtype: torch.dtype = torch.bfloat16,
75
75
  ):
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
82
82
  self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
83
83
  self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
84
84
  self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
85
- self.attn_impl = attn_impl
85
+ self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
86
86
 
87
87
  def forward(self, x, freqs):
88
88
  q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
@@ -94,7 +94,7 @@ class SelfAttention(nn.Module):
94
94
  q=rope_apply(q, freqs),
95
95
  k=rope_apply(k, freqs),
96
96
  v=v,
97
- attn_impl=self.attn_impl,
97
+ **self.attn_kwargs,
98
98
  )
99
99
  x = x.flatten(2)
100
100
  return self.o(x)
@@ -107,7 +107,7 @@ class CrossAttention(nn.Module):
107
107
  num_heads: int,
108
108
  eps: float = 1e-6,
109
109
  has_image_input: bool = False,
110
- attn_impl: Optional[str] = None,
110
+ attn_kwargs: Optional[Dict[str, Any]] = None,
111
111
  device: str = "cuda:0",
112
112
  dtype: torch.dtype = torch.bfloat16,
113
113
  ):
@@ -126,7 +126,7 @@ class CrossAttention(nn.Module):
126
126
  self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
127
127
  self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
128
128
  self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
129
- self.attn_impl = attn_impl
129
+ self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
130
130
 
131
131
  def forward(self, x: torch.Tensor, y: torch.Tensor):
132
132
  if self.has_image_input:
@@ -140,12 +140,12 @@ class CrossAttention(nn.Module):
140
140
  k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
141
141
  v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
142
142
 
143
- x = attention_ops.attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
143
+ x = attention_ops.attention(q, k, v, **self.attn_kwargs).flatten(2)
144
144
  if self.has_image_input:
145
145
  k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
146
146
  k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
147
147
  v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
148
- y = attention_ops.attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
148
+ y = attention_ops.attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
149
149
  x = x + y
150
150
  return self.o(x)
151
151
 
@@ -158,7 +158,7 @@ class DiTBlock(nn.Module):
158
158
  num_heads: int,
159
159
  ffn_dim: int,
160
160
  eps: float = 1e-6,
161
- attn_impl: Optional[str] = None,
161
+ attn_kwargs: Optional[Dict[str, Any]] = None,
162
162
  device: str = "cuda:0",
163
163
  dtype: torch.dtype = torch.bfloat16,
164
164
  ):
@@ -166,9 +166,9 @@ class DiTBlock(nn.Module):
166
166
  self.dim = dim
167
167
  self.num_heads = num_heads
168
168
  self.ffn_dim = ffn_dim
169
- self.self_attn = SelfAttention(dim, num_heads, eps, attn_impl=attn_impl, device=device, dtype=dtype)
169
+ self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
170
170
  self.cross_attn = CrossAttention(
171
- dim, num_heads, eps, has_image_input=has_image_input, attn_impl=attn_impl, device=device, dtype=dtype
171
+ dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
172
172
  )
173
173
  self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
174
174
  self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -265,7 +265,7 @@ class WanDiT(PreTrainedModel):
265
265
  num_layers: int,
266
266
  has_image_input: bool,
267
267
  flf_pos_emb: bool = False,
268
- attn_impl: Optional[str] = None,
268
+ attn_kwargs: Optional[Dict[str, Any]] = None,
269
269
  device: str = "cpu",
270
270
  dtype: torch.dtype = torch.bfloat16,
271
271
  ):
@@ -296,7 +296,7 @@ class WanDiT(PreTrainedModel):
296
296
  )
297
297
  self.blocks = nn.ModuleList(
298
298
  [
299
- DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_impl, device=device, dtype=dtype)
299
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
300
300
  for _ in range(num_layers)
301
301
  ]
302
302
  )
@@ -376,7 +376,7 @@ class WanDiT(PreTrainedModel):
376
376
  device,
377
377
  dtype,
378
378
  model_type="1.3b-t2v",
379
- attn_impl: Optional[str] = None,
379
+ attn_kwargs: Optional[Dict[str, Any]] = None,
380
380
  assign=True,
381
381
  ):
382
382
  if model_type == "1.3b-t2v":
@@ -390,7 +390,7 @@ class WanDiT(PreTrainedModel):
390
390
  else:
391
391
  raise ValueError(f"Unsupported model type: {model_type}")
392
392
  with no_init_weights():
393
- model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_impl=attn_impl)
393
+ model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_kwargs=attn_kwargs)
394
394
  model = model.requires_grad_(False)
395
395
  model.load_state_dict(state_dict, assign=assign)
396
396
  model.to(device=device, dtype=dtype)
@@ -1,20 +1,17 @@
1
1
  from .base import BasePipeline, LoRAStateDictConverter
2
2
  from .controlnet_helper import ControlNetParams
3
- from .flux_image import FluxImagePipeline, FluxModelConfig
4
- from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
5
- from .sd_image import SDImagePipeline, SDModelConfig
6
- from .wan_video import WanVideoPipeline, WanModelConfig
3
+ from .flux_image import FluxImagePipeline
4
+ from .sdxl_image import SDXLImagePipeline
5
+ from .sd_image import SDImagePipeline
6
+ from .wan_video import WanVideoPipeline
7
+
7
8
 
8
9
  __all__ = [
9
10
  "BasePipeline",
10
11
  "LoRAStateDictConverter",
11
12
  "FluxImagePipeline",
12
- "FluxModelConfig",
13
13
  "SDXLImagePipeline",
14
- "SDXLModelConfig",
15
14
  "SDImagePipeline",
16
- "SDModelConfig",
17
15
  "WanVideoPipeline",
18
- "WanModelConfig",
19
16
  "ControlNetParams",
20
17
  ]
@@ -3,7 +3,7 @@ import torch
3
3
  import numpy as np
4
4
  from typing import Dict, List, Tuple
5
5
  from PIL import Image
6
- from dataclasses import dataclass
6
+ from diffsynth_engine.configs import BaseConfig
7
7
  from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
8
8
  from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
9
9
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -14,11 +14,6 @@ from diffsynth_engine.utils.platform import empty_cache
14
14
  logger = logging.get_logger(__name__)
15
15
 
16
16
 
17
- @dataclass
18
- class ModelConfig:
19
- pass
20
-
21
-
22
17
  class LoRAStateDictConverter:
23
18
  def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
24
19
  return {"lora": lora_state_dict}
@@ -30,8 +25,8 @@ class BasePipeline:
30
25
  def __init__(
31
26
  self,
32
27
  vae_tiled: bool = False,
33
- vae_tile_size: int = -1,
34
- vae_tile_stride: int = -1,
28
+ vae_tile_size: int | Tuple[int, int] = -1,
29
+ vae_tile_stride: int | Tuple[int, int] = -1,
35
30
  device="cuda",
36
31
  dtype=torch.float16,
37
32
  ):
@@ -46,13 +41,7 @@ class BasePipeline:
46
41
  self._models_offload_params = {}
47
42
 
48
43
  @classmethod
49
- def from_pretrained(
50
- cls,
51
- model_path_or_config: str | os.PathLike | ModelConfig,
52
- device: str = "cuda",
53
- dtype: torch.dtype = torch.float16,
54
- offload_mode: str | None = None,
55
- ) -> "BasePipeline":
44
+ def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
56
45
  raise NotImplementedError()
57
46
 
58
47
  @classmethod
@@ -224,54 +213,6 @@ class BasePipeline:
224
213
  model.eval()
225
214
  return self
226
215
 
227
- @staticmethod
228
- def init_parallel_config(
229
- parallelism: int,
230
- use_cfg_parallel: bool,
231
- model_config: ModelConfig,
232
- ):
233
- assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
234
- cfg_degree = 2 if use_cfg_parallel else 1
235
- sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
236
- sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
237
- tp_degree = getattr(model_config, "tp_degree", None)
238
- use_fsdp = getattr(model_config, "use_fsdp", False)
239
-
240
- if tp_degree is not None:
241
- assert sp_ulysses_degree is None and sp_ring_degree is None, (
242
- "not allowed to enable sequence parallel and tensor parallel together; "
243
- "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
244
- )
245
- assert use_fsdp is False, (
246
- "not allowed to enable fully sharded data parallel and tensor parallel together; "
247
- "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
248
- )
249
- assert parallelism == cfg_degree * tp_degree, (
250
- f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
251
- )
252
- sp_ulysses_degree = 1
253
- sp_ring_degree = 1
254
- elif sp_ulysses_degree is None and sp_ring_degree is None:
255
- # use ulysses if not specified
256
- sp_ulysses_degree = parallelism // cfg_degree
257
- sp_ring_degree = 1
258
- tp_degree = 1
259
- elif sp_ulysses_degree is not None and sp_ring_degree is not None:
260
- assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
261
- f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
262
- f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
263
- )
264
- tp_degree = 1
265
- else:
266
- raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
267
- return {
268
- "cfg_degree": cfg_degree,
269
- "sp_ulysses_degree": sp_ulysses_degree,
270
- "sp_ring_degree": sp_ring_degree,
271
- "tp_degree": tp_degree,
272
- "use_fsdp": use_fsdp,
273
- }
274
-
275
216
  def enable_cpu_offload(self, offload_mode: str):
276
217
  valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
277
218
  if offload_mode not in valid_offload_mode:
@@ -326,14 +267,22 @@ class BasePipeline:
326
267
  for model_name in self.model_names:
327
268
  if model_name not in load_model_names:
328
269
  model = getattr(self, model_name)
329
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device("cpu"):
270
+ if (
271
+ model is not None
272
+ and (p := next(model.parameters(), None)) is not None
273
+ and p.device != torch.device("cpu")
274
+ ):
330
275
  param_cache = self._models_offload_params[model_name]
331
276
  for name, param in model.named_parameters(recurse=True):
332
277
  param.data = param_cache[name]
333
278
  # load the needed models to device
334
279
  for model_name in load_model_names:
335
280
  model = getattr(self, model_name)
336
- if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device(self.device):
281
+ if (
282
+ model is not None
283
+ and (p := next(model.parameters(), None)) is not None
284
+ and p.device != torch.device(self.device)
285
+ ):
337
286
  model.to(self.device)
338
287
  # fresh the cuda cache
339
288
  empty_cache()
@@ -1,15 +1,12 @@
1
1
  import re
2
- import os
3
2
  import json
4
3
  import torch
5
4
  import torch.distributed as dist
6
5
  import math
7
6
  from einops import rearrange
8
- from enum import Enum
9
7
  from typing import Callable, Dict, List, Tuple, Optional, Union
10
8
  from tqdm import tqdm
11
9
  from PIL import Image
12
- from dataclasses import dataclass
13
10
  from diffsynth_engine.models.flux import (
14
11
  FluxTextEncoder1,
15
12
  FluxTextEncoder2,
@@ -20,6 +17,7 @@ from diffsynth_engine.models.flux import (
20
17
  flux_dit_config,
21
18
  flux_text_encoder_config,
22
19
  )
20
+ from diffsynth_engine.configs import FluxPipelineConfig, ControlType
23
21
  from diffsynth_engine.models.basic.lora import LoRAContext
24
22
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
25
23
  from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
@@ -416,48 +414,12 @@ def calculate_shift(
416
414
  return mu
417
415
 
418
416
 
419
- class ControlType(Enum):
420
- normal = "normal"
421
- bfl_control = "bfl_control"
422
- bfl_fill = "bfl_fill"
423
- bfl_kontext = "bfl_kontext"
424
-
425
- def get_in_channel(self):
426
- if self in [ControlType.normal, ControlType.bfl_kontext]:
427
- return 64
428
- elif self == ControlType.bfl_control:
429
- return 128
430
- elif self == ControlType.bfl_fill:
431
- return 384
432
-
433
-
434
- @dataclass
435
- class FluxModelConfig:
436
- dit_path: str | os.PathLike
437
- clip_path: Optional[str | os.PathLike] = None
438
- t5_path: Optional[str | os.PathLike] = None
439
- vae_path: Optional[str | os.PathLike] = None
440
-
441
- dit_dtype: torch.dtype = torch.bfloat16
442
- clip_dtype: torch.dtype = torch.bfloat16
443
- t5_dtype: torch.dtype = torch.bfloat16
444
- vae_dtype: torch.dtype = torch.bfloat16
445
-
446
- dit_attn_impl: Optional[str] = "auto"
447
- use_fp8_linear: bool = False
448
-
449
- sp_ulysses_degree: Optional[int] = None
450
- sp_ring_degree: Optional[int] = None
451
- tp_degree: Optional[int] = None
452
- use_fsdp: bool = False
453
-
454
-
455
417
  class FluxImagePipeline(BasePipeline):
456
418
  lora_converter = FluxLoRAConverter()
457
419
 
458
420
  def __init__(
459
421
  self,
460
- config: FluxModelConfig,
422
+ config: FluxPipelineConfig,
461
423
  tokenizer: CLIPTokenizer,
462
424
  tokenizer_2: T5TokenizerFast,
463
425
  text_encoder_1: FluxTextEncoder1,
@@ -465,23 +427,16 @@ class FluxImagePipeline(BasePipeline):
465
427
  dit: Union[FluxDiT, FluxDiTFBCache],
466
428
  vae_decoder: FluxVAEDecoder,
467
429
  vae_encoder: FluxVAEEncoder,
468
- load_text_encoder: bool = True,
469
- batch_cfg: bool = False,
470
- vae_tiled: bool = False,
471
- vae_tile_size: int = 256,
472
- vae_tile_stride: int = 256,
473
- control_type: ControlType = ControlType.normal,
474
- device: str = "cuda",
475
- dtype: torch.dtype = torch.bfloat16,
476
430
  ):
477
431
  super().__init__(
478
- vae_tiled=vae_tiled,
479
- vae_tile_size=vae_tile_size,
480
- vae_tile_stride=vae_tile_stride,
481
- device=device,
482
- dtype=dtype,
432
+ vae_tiled=config.vae_tiled,
433
+ vae_tile_size=config.vae_tile_size,
434
+ vae_tile_stride=config.vae_tile_stride,
435
+ device=config.device,
436
+ dtype=config.model_dtype,
483
437
  )
484
438
  self.config = config
439
+ # sampler
485
440
  self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
486
441
  self.sampler = FlowMatchEulerSampler()
487
442
  # models
@@ -492,11 +447,8 @@ class FluxImagePipeline(BasePipeline):
492
447
  self.dit = dit
493
448
  self.vae_decoder = vae_decoder
494
449
  self.vae_encoder = vae_encoder
495
- self.load_text_encoder = load_text_encoder
496
- self.batch_cfg = batch_cfg
497
450
  self.ip_adapter = None
498
451
  self.redux = None
499
- self.control_type = control_type
500
452
  self.model_names = [
501
453
  "text_encoder_1",
502
454
  "text_encoder_2",
@@ -506,140 +458,115 @@ class FluxImagePipeline(BasePipeline):
506
458
  ]
507
459
 
508
460
  @classmethod
509
- def from_pretrained(
510
- cls,
511
- model_path_or_config: str | os.PathLike | FluxModelConfig,
512
- load_text_encoder: bool = True,
513
- batch_cfg: bool = False,
514
- vae_tiled: bool = False,
515
- vae_tile_size: int = 256,
516
- vae_tile_stride: int = 256,
517
- control_type: ControlType = ControlType.normal,
518
- device: str = "cuda",
519
- dtype: torch.dtype = torch.bfloat16,
520
- offload_mode: str | None = None,
521
- parallelism: int = 1,
522
- use_cfg_parallel: bool = False,
523
- use_fb_cache: bool = False,
524
- fb_cache_relative_l1_threshold: float = 0.05,
525
- ) -> "FluxImagePipeline":
526
- model_config = (
527
- model_path_or_config
528
- if isinstance(model_path_or_config, FluxModelConfig)
529
- else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
530
- )
531
- if model_config.vae_path is None:
532
- model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
533
-
534
- if model_config.clip_path is None and load_text_encoder:
535
- model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
536
- if model_config.t5_path is None and load_text_encoder:
537
- model_config.t5_path = fetch_model(
461
+ def from_pretrained(cls, model_path_or_config: str | FluxPipelineConfig) -> "FluxImagePipeline":
462
+ if isinstance(model_path_or_config, str):
463
+ config = FluxPipelineConfig(model_path=model_path_or_config)
464
+ else:
465
+ config = model_path_or_config
466
+
467
+ if config.vae_path is None:
468
+ config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
469
+ if config.clip_path is None and config.load_text_encoder:
470
+ config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
471
+ if config.t5_path is None and config.load_text_encoder:
472
+ config.t5_path = fetch_model(
538
473
  "muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
539
474
  )
540
475
 
541
- logger.info(f"loading state dict from {model_config.dit_path} ...")
542
- dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=model_config.dit_dtype)
543
- logger.info(f"loading state dict from {model_config.vae_path} ...")
544
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)
545
- if load_text_encoder:
546
- logger.info(f"loading state dict from {model_config.clip_path} ...")
547
- clip_state_dict = cls.load_model_checkpoint(
548
- model_config.clip_path, device="cpu", dtype=model_config.clip_dtype
549
- )
550
- logger.info(f"loading state dict from {model_config.t5_path} ...")
551
- t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)
552
-
553
- init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
554
- if load_text_encoder:
476
+ logger.info(f"loading state dict from {config.model_path} ...")
477
+ dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
478
+ logger.info(f"loading state dict from {config.vae_path} ...")
479
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
480
+ if config.load_text_encoder:
481
+ logger.info(f"loading state dict from {config.clip_path} ...")
482
+ clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
483
+ logger.info(f"loading state dict from {config.t5_path} ...")
484
+ t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
485
+
486
+ init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
487
+ if config.load_text_encoder:
555
488
  tokenizer = CLIPTokenizer.from_pretrained(FLUX_TOKENIZER_1_CONF_PATH)
556
489
  tokenizer_2 = T5TokenizerFast.from_pretrained(FLUX_TOKENIZER_2_CONF_PATH)
557
490
  with LoRAContext():
558
491
  text_encoder_1 = FluxTextEncoder1.from_state_dict(
559
- clip_state_dict, device=init_device, dtype=model_config.clip_dtype
492
+ clip_state_dict, device=init_device, dtype=config.clip_dtype
560
493
  )
561
- text_encoder_2 = FluxTextEncoder2.from_state_dict(
562
- t5_state_dict, device=init_device, dtype=model_config.t5_dtype
563
- )
494
+ text_encoder_2 = FluxTextEncoder2.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
564
495
 
565
- vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
566
- vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
496
+ vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
497
+ vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
567
498
 
568
499
  with LoRAContext():
569
- if use_fb_cache:
500
+ attn_kwargs = {
501
+ "attn_impl": config.dit_attn_impl,
502
+ "sparge_smooth_k": config.sparge_smooth_k,
503
+ "sparge_cdfthreshd": config.sparge_cdfthreshd,
504
+ "sparge_simthreshd1": config.sparge_simthreshd1,
505
+ "sparge_pvthreshd": config.sparge_pvthreshd,
506
+ }
507
+ if config.use_fbcache:
570
508
  dit = FluxDiTFBCache.from_state_dict(
571
509
  dit_state_dict,
572
510
  device=init_device,
573
- dtype=model_config.dit_dtype,
574
- in_channel=control_type.get_in_channel(),
575
- attn_impl=model_config.dit_attn_impl,
576
- relative_l1_threshold=fb_cache_relative_l1_threshold,
511
+ dtype=config.model_dtype,
512
+ in_channel=config.control_type.get_in_channel(),
513
+ attn_kwargs=attn_kwargs,
514
+ relative_l1_threshold=config.fbcache_relative_l1_threshold,
577
515
  )
578
516
  else:
579
517
  dit = FluxDiT.from_state_dict(
580
518
  dit_state_dict,
581
519
  device=init_device,
582
- dtype=model_config.dit_dtype,
583
- in_channel=control_type.get_in_channel(),
584
- attn_impl=model_config.dit_attn_impl,
520
+ dtype=config.model_dtype,
521
+ in_channel=config.control_type.get_in_channel(),
522
+ attn_kwargs=attn_kwargs,
585
523
  )
586
- if model_config.use_fp8_linear:
524
+ if config.use_fp8_linear:
587
525
  enable_fp8_linear(dit)
588
526
 
589
527
  pipe = cls(
590
- config=model_config,
591
- tokenizer=tokenizer if load_text_encoder else None,
592
- tokenizer_2=tokenizer_2 if load_text_encoder else None,
593
- text_encoder_1=text_encoder_1 if load_text_encoder else None,
594
- text_encoder_2=text_encoder_2 if load_text_encoder else None,
528
+ config=config,
529
+ tokenizer=tokenizer if config.load_text_encoder else None,
530
+ tokenizer_2=tokenizer_2 if config.load_text_encoder else None,
531
+ text_encoder_1=text_encoder_1 if config.load_text_encoder else None,
532
+ text_encoder_2=text_encoder_2 if config.load_text_encoder else None,
595
533
  dit=dit,
596
534
  vae_decoder=vae_decoder,
597
535
  vae_encoder=vae_encoder,
598
- load_text_encoder=load_text_encoder,
599
- batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
600
- vae_tiled=vae_tiled,
601
- vae_tile_size=vae_tile_size,
602
- vae_tile_stride=vae_tile_stride,
603
- control_type=control_type,
604
- device=device,
605
- dtype=model_config.dit_dtype,
606
536
  )
607
- if offload_mode is not None:
608
- pipe.enable_cpu_offload(offload_mode)
609
- if model_config.dit_dtype == torch.float8_e4m3fn:
610
- pipe.dtype = torch.bfloat16 # running dtype
537
+ pipe.eval()
538
+
539
+ if config.offload_mode is not None:
540
+ pipe.enable_cpu_offload(config.offload_mode)
541
+
542
+ if config.model_dtype == torch.float8_e4m3fn:
543
+ pipe.dtype = torch.bfloat16 # compute dtype
611
544
  pipe.enable_fp8_autocast(
612
- model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
545
+ model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
613
546
  )
614
547
 
615
- if model_config.t5_dtype == torch.float8_e4m3fn:
616
- pipe.dtype = torch.bfloat16 # running dtype
548
+ if config.t5_dtype == torch.float8_e4m3fn:
549
+ pipe.dtype = torch.bfloat16 # compute dtype
617
550
  pipe.enable_fp8_autocast(
618
- model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
551
+ model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
619
552
  )
620
553
 
621
- if parallelism > 1:
622
- parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
623
- cfg_degree = parallel_config["cfg_degree"]
624
- sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
625
- sp_ring_degree = parallel_config["sp_ring_degree"]
626
- tp_degree = parallel_config["tp_degree"]
627
- use_fsdp = parallel_config["use_fsdp"]
554
+ if config.parallelism > 1:
628
555
  return ParallelWrapper(
629
556
  pipe,
630
- cfg_degree=cfg_degree,
631
- sp_ulysses_degree=sp_ulysses_degree,
632
- sp_ring_degree=sp_ring_degree,
633
- tp_degree=tp_degree,
634
- use_fsdp=use_fsdp,
557
+ cfg_degree=config.cfg_degree,
558
+ sp_ulysses_degree=config.sp_ulysses_degree,
559
+ sp_ring_degree=config.sp_ring_degree,
560
+ tp_degree=config.tp_degree,
561
+ use_fsdp=config.use_fsdp,
635
562
  device="cuda",
636
563
  )
637
564
  return pipe
638
565
 
639
566
  def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
640
- assert self.config.tp_degree is None, (
567
+ assert self.config.tp_degree is None or self.config.tp_degree == 1, (
641
568
  "load LoRA is not allowed when tensor parallel is enabled; "
642
- "set tp_degree=None during pipeline initialization"
569
+ "set tp_degree=None or tp_degree=1 during pipeline initialization"
643
570
  )
644
571
  assert not (self.config.use_fsdp and fused), (
645
572
  "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
@@ -791,9 +718,9 @@ class FluxImagePipeline(BasePipeline):
791
718
  total_step: int,
792
719
  ):
793
720
  origin_latents_shape = latents.shape
794
- if self.control_type != ControlType.normal:
721
+ if self.config.control_type != ControlType.normal:
795
722
  controlnet_param = controlnet_params[0]
796
- if self.control_type == ControlType.bfl_kontext:
723
+ if self.config.control_type == ControlType.bfl_kontext:
797
724
  latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
798
725
  image_ids = image_ids.repeat(1, 2, 1)
799
726
  image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
@@ -828,7 +755,7 @@ class FluxImagePipeline(BasePipeline):
828
755
  controlnet_double_block_output=double_block_output,
829
756
  controlnet_single_block_output=single_block_output,
830
757
  )
831
- if self.control_type == ControlType.bfl_kontext:
758
+ if self.config.control_type == ControlType.bfl_kontext:
832
759
  noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
833
760
  return noise_pred
834
761
 
@@ -877,7 +804,7 @@ class FluxImagePipeline(BasePipeline):
877
804
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
878
805
  latent = self.encode_image(image)
879
806
  else:
880
- if self.control_type == ControlType.normal:
807
+ if self.config.control_type == ControlType.normal:
881
808
  image = image.resize((width, height))
882
809
  mask = mask.resize((width, height))
883
810
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
@@ -888,7 +815,7 @@ class FluxImagePipeline(BasePipeline):
888
815
  mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
889
816
  mask = 1 - mask
890
817
  latent = torch.cat([latent, mask], dim=1)
891
- elif self.control_type == ControlType.bfl_fill:
818
+ elif self.config.control_type == ControlType.bfl_fill:
892
819
  image = image.resize((width, height))
893
820
  mask = mask.resize((width, height))
894
821
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
@@ -898,7 +825,7 @@ class FluxImagePipeline(BasePipeline):
898
825
  mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
899
826
  latent = torch.cat((image, mask), dim=1)
900
827
  else:
901
- raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.control_type}")
828
+ raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.config.control_type}")
902
829
  return latent
903
830
 
904
831
  def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
@@ -996,7 +923,7 @@ class FluxImagePipeline(BasePipeline):
996
923
  self.dit.refresh_cache_status(num_inference_steps)
997
924
  if not isinstance(controlnet_params, list):
998
925
  controlnet_params = [controlnet_params]
999
- if self.control_type != ControlType.normal:
926
+ if self.config.control_type != ControlType.normal:
1000
927
  assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
1001
928
 
1002
929
  if input_image is not None:
@@ -1033,7 +960,7 @@ class FluxImagePipeline(BasePipeline):
1033
960
  if self.ip_adapter is not None and self.redux is not None:
1034
961
  raise Exception("ip-adapter and flux redux cannot be used at the same time")
1035
962
  elif self.ip_adapter is not None:
1036
- image_emb = self.ip_adapter(ref_image)
963
+ image_emb = self.ip_adapter.encode_image(ref_image)
1037
964
  elif self.redux is not None:
1038
965
  image_prompt_embeds = self.redux(ref_image)
1039
966
  positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
@@ -1063,7 +990,7 @@ class FluxImagePipeline(BasePipeline):
1063
990
  controlnet_params=controlnet_params,
1064
991
  current_step=i,
1065
992
  total_step=len(timesteps),
1066
- batch_cfg=self.batch_cfg,
993
+ batch_cfg=self.config.batch_cfg,
1067
994
  )
1068
995
  # Denoise
1069
996
  latents = self.sampler.step(latents, noise_pred, i)