diffsynth-engine 0.3.6.dev9__py3-none-any.whl → 0.3.6.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +10 -8
- diffsynth_engine/configs/__init__.py +23 -0
- diffsynth_engine/configs/controlnet.py +17 -0
- diffsynth_engine/configs/pipeline.py +206 -0
- diffsynth_engine/models/basic/attention.py +43 -4
- diffsynth_engine/models/flux/flux_controlnet.py +8 -5
- diffsynth_engine/models/flux/flux_dit.py +22 -16
- diffsynth_engine/models/flux/flux_dit_fbcache.py +7 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/sd/sd_controlnet.py +2 -4
- diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
- diffsynth_engine/models/wan/wan_dit.py +15 -15
- diffsynth_engine/pipelines/__init__.py +5 -8
- diffsynth_engine/pipelines/base.py +14 -65
- diffsynth_engine/pipelines/flux_image.py +85 -158
- diffsynth_engine/pipelines/sd_image.py +30 -64
- diffsynth_engine/pipelines/sdxl_image.py +39 -71
- diffsynth_engine/pipelines/wan_video.py +66 -105
- diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
- diffsynth_engine/tools/flux_reference_tool.py +21 -5
- diffsynth_engine/tools/flux_replace_tool.py +15 -3
- diffsynth_engine/utils/parallel.py +1 -1
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/RECORD +28 -25
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/top_level.txt +0 -0
|
@@ -1,15 +1,12 @@
|
|
|
1
1
|
import re
|
|
2
|
-
import os
|
|
3
2
|
import json
|
|
4
3
|
import torch
|
|
5
4
|
import torch.distributed as dist
|
|
6
5
|
import math
|
|
7
6
|
from einops import rearrange
|
|
8
|
-
from enum import Enum
|
|
9
7
|
from typing import Callable, Dict, List, Tuple, Optional, Union
|
|
10
8
|
from tqdm import tqdm
|
|
11
9
|
from PIL import Image
|
|
12
|
-
from dataclasses import dataclass
|
|
13
10
|
from diffsynth_engine.models.flux import (
|
|
14
11
|
FluxTextEncoder1,
|
|
15
12
|
FluxTextEncoder2,
|
|
@@ -20,6 +17,7 @@ from diffsynth_engine.models.flux import (
|
|
|
20
17
|
flux_dit_config,
|
|
21
18
|
flux_text_encoder_config,
|
|
22
19
|
)
|
|
20
|
+
from diffsynth_engine.configs import FluxPipelineConfig, ControlType
|
|
23
21
|
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
24
22
|
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
|
|
25
23
|
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
|
|
@@ -416,48 +414,12 @@ def calculate_shift(
|
|
|
416
414
|
return mu
|
|
417
415
|
|
|
418
416
|
|
|
419
|
-
class ControlType(Enum):
|
|
420
|
-
normal = "normal"
|
|
421
|
-
bfl_control = "bfl_control"
|
|
422
|
-
bfl_fill = "bfl_fill"
|
|
423
|
-
bfl_kontext = "bfl_kontext"
|
|
424
|
-
|
|
425
|
-
def get_in_channel(self):
|
|
426
|
-
if self in [ControlType.normal, ControlType.bfl_kontext]:
|
|
427
|
-
return 64
|
|
428
|
-
elif self == ControlType.bfl_control:
|
|
429
|
-
return 128
|
|
430
|
-
elif self == ControlType.bfl_fill:
|
|
431
|
-
return 384
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
@dataclass
|
|
435
|
-
class FluxModelConfig:
|
|
436
|
-
dit_path: str | os.PathLike
|
|
437
|
-
clip_path: Optional[str | os.PathLike] = None
|
|
438
|
-
t5_path: Optional[str | os.PathLike] = None
|
|
439
|
-
vae_path: Optional[str | os.PathLike] = None
|
|
440
|
-
|
|
441
|
-
dit_dtype: torch.dtype = torch.bfloat16
|
|
442
|
-
clip_dtype: torch.dtype = torch.bfloat16
|
|
443
|
-
t5_dtype: torch.dtype = torch.bfloat16
|
|
444
|
-
vae_dtype: torch.dtype = torch.bfloat16
|
|
445
|
-
|
|
446
|
-
dit_attn_impl: Optional[str] = "auto"
|
|
447
|
-
use_fp8_linear: bool = False
|
|
448
|
-
|
|
449
|
-
sp_ulysses_degree: Optional[int] = None
|
|
450
|
-
sp_ring_degree: Optional[int] = None
|
|
451
|
-
tp_degree: Optional[int] = None
|
|
452
|
-
use_fsdp: bool = False
|
|
453
|
-
|
|
454
|
-
|
|
455
417
|
class FluxImagePipeline(BasePipeline):
|
|
456
418
|
lora_converter = FluxLoRAConverter()
|
|
457
419
|
|
|
458
420
|
def __init__(
|
|
459
421
|
self,
|
|
460
|
-
config:
|
|
422
|
+
config: FluxPipelineConfig,
|
|
461
423
|
tokenizer: CLIPTokenizer,
|
|
462
424
|
tokenizer_2: T5TokenizerFast,
|
|
463
425
|
text_encoder_1: FluxTextEncoder1,
|
|
@@ -465,23 +427,16 @@ class FluxImagePipeline(BasePipeline):
|
|
|
465
427
|
dit: Union[FluxDiT, FluxDiTFBCache],
|
|
466
428
|
vae_decoder: FluxVAEDecoder,
|
|
467
429
|
vae_encoder: FluxVAEEncoder,
|
|
468
|
-
load_text_encoder: bool = True,
|
|
469
|
-
batch_cfg: bool = False,
|
|
470
|
-
vae_tiled: bool = False,
|
|
471
|
-
vae_tile_size: int = 256,
|
|
472
|
-
vae_tile_stride: int = 256,
|
|
473
|
-
control_type: ControlType = ControlType.normal,
|
|
474
|
-
device: str = "cuda",
|
|
475
|
-
dtype: torch.dtype = torch.bfloat16,
|
|
476
430
|
):
|
|
477
431
|
super().__init__(
|
|
478
|
-
vae_tiled=vae_tiled,
|
|
479
|
-
vae_tile_size=vae_tile_size,
|
|
480
|
-
vae_tile_stride=vae_tile_stride,
|
|
481
|
-
device=device,
|
|
482
|
-
dtype=
|
|
432
|
+
vae_tiled=config.vae_tiled,
|
|
433
|
+
vae_tile_size=config.vae_tile_size,
|
|
434
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
435
|
+
device=config.device,
|
|
436
|
+
dtype=config.model_dtype,
|
|
483
437
|
)
|
|
484
438
|
self.config = config
|
|
439
|
+
# sampler
|
|
485
440
|
self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
|
|
486
441
|
self.sampler = FlowMatchEulerSampler()
|
|
487
442
|
# models
|
|
@@ -492,11 +447,8 @@ class FluxImagePipeline(BasePipeline):
|
|
|
492
447
|
self.dit = dit
|
|
493
448
|
self.vae_decoder = vae_decoder
|
|
494
449
|
self.vae_encoder = vae_encoder
|
|
495
|
-
self.load_text_encoder = load_text_encoder
|
|
496
|
-
self.batch_cfg = batch_cfg
|
|
497
450
|
self.ip_adapter = None
|
|
498
451
|
self.redux = None
|
|
499
|
-
self.control_type = control_type
|
|
500
452
|
self.model_names = [
|
|
501
453
|
"text_encoder_1",
|
|
502
454
|
"text_encoder_2",
|
|
@@ -506,140 +458,115 @@ class FluxImagePipeline(BasePipeline):
|
|
|
506
458
|
]
|
|
507
459
|
|
|
508
460
|
@classmethod
|
|
509
|
-
def from_pretrained(
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
parallelism: int = 1,
|
|
522
|
-
use_cfg_parallel: bool = False,
|
|
523
|
-
use_fb_cache: bool = False,
|
|
524
|
-
fb_cache_relative_l1_threshold: float = 0.05,
|
|
525
|
-
) -> "FluxImagePipeline":
|
|
526
|
-
model_config = (
|
|
527
|
-
model_path_or_config
|
|
528
|
-
if isinstance(model_path_or_config, FluxModelConfig)
|
|
529
|
-
else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
|
|
530
|
-
)
|
|
531
|
-
if model_config.vae_path is None:
|
|
532
|
-
model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
|
|
533
|
-
|
|
534
|
-
if model_config.clip_path is None and load_text_encoder:
|
|
535
|
-
model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
|
|
536
|
-
if model_config.t5_path is None and load_text_encoder:
|
|
537
|
-
model_config.t5_path = fetch_model(
|
|
461
|
+
def from_pretrained(cls, model_path_or_config: str | FluxPipelineConfig) -> "FluxImagePipeline":
|
|
462
|
+
if isinstance(model_path_or_config, str):
|
|
463
|
+
config = FluxPipelineConfig(model_path=model_path_or_config)
|
|
464
|
+
else:
|
|
465
|
+
config = model_path_or_config
|
|
466
|
+
|
|
467
|
+
if config.vae_path is None:
|
|
468
|
+
config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
|
|
469
|
+
if config.clip_path is None and config.load_text_encoder:
|
|
470
|
+
config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
|
|
471
|
+
if config.t5_path is None and config.load_text_encoder:
|
|
472
|
+
config.t5_path = fetch_model(
|
|
538
473
|
"muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
|
|
539
474
|
)
|
|
540
475
|
|
|
541
|
-
logger.info(f"loading state dict from {
|
|
542
|
-
dit_state_dict = cls.load_model_checkpoint(
|
|
543
|
-
logger.info(f"loading state dict from {
|
|
544
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
545
|
-
if load_text_encoder:
|
|
546
|
-
logger.info(f"loading state dict from {
|
|
547
|
-
clip_state_dict = cls.load_model_checkpoint(
|
|
548
|
-
|
|
549
|
-
)
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
|
|
554
|
-
if load_text_encoder:
|
|
476
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
477
|
+
dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
478
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
479
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
480
|
+
if config.load_text_encoder:
|
|
481
|
+
logger.info(f"loading state dict from {config.clip_path} ...")
|
|
482
|
+
clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
|
|
483
|
+
logger.info(f"loading state dict from {config.t5_path} ...")
|
|
484
|
+
t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
|
|
485
|
+
|
|
486
|
+
init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
|
|
487
|
+
if config.load_text_encoder:
|
|
555
488
|
tokenizer = CLIPTokenizer.from_pretrained(FLUX_TOKENIZER_1_CONF_PATH)
|
|
556
489
|
tokenizer_2 = T5TokenizerFast.from_pretrained(FLUX_TOKENIZER_2_CONF_PATH)
|
|
557
490
|
with LoRAContext():
|
|
558
491
|
text_encoder_1 = FluxTextEncoder1.from_state_dict(
|
|
559
|
-
clip_state_dict, device=init_device, dtype=
|
|
492
|
+
clip_state_dict, device=init_device, dtype=config.clip_dtype
|
|
560
493
|
)
|
|
561
|
-
text_encoder_2 = FluxTextEncoder2.from_state_dict(
|
|
562
|
-
t5_state_dict, device=init_device, dtype=model_config.t5_dtype
|
|
563
|
-
)
|
|
494
|
+
text_encoder_2 = FluxTextEncoder2.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
|
|
564
495
|
|
|
565
|
-
vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=
|
|
566
|
-
vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=
|
|
496
|
+
vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
|
|
497
|
+
vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
|
|
567
498
|
|
|
568
499
|
with LoRAContext():
|
|
569
|
-
|
|
500
|
+
attn_kwargs = {
|
|
501
|
+
"attn_impl": config.dit_attn_impl,
|
|
502
|
+
"sparge_smooth_k": config.sparge_smooth_k,
|
|
503
|
+
"sparge_cdfthreshd": config.sparge_cdfthreshd,
|
|
504
|
+
"sparge_simthreshd1": config.sparge_simthreshd1,
|
|
505
|
+
"sparge_pvthreshd": config.sparge_pvthreshd,
|
|
506
|
+
}
|
|
507
|
+
if config.use_fbcache:
|
|
570
508
|
dit = FluxDiTFBCache.from_state_dict(
|
|
571
509
|
dit_state_dict,
|
|
572
510
|
device=init_device,
|
|
573
|
-
dtype=
|
|
574
|
-
in_channel=control_type.get_in_channel(),
|
|
575
|
-
|
|
576
|
-
relative_l1_threshold=
|
|
511
|
+
dtype=config.model_dtype,
|
|
512
|
+
in_channel=config.control_type.get_in_channel(),
|
|
513
|
+
attn_kwargs=attn_kwargs,
|
|
514
|
+
relative_l1_threshold=config.fbcache_relative_l1_threshold,
|
|
577
515
|
)
|
|
578
516
|
else:
|
|
579
517
|
dit = FluxDiT.from_state_dict(
|
|
580
518
|
dit_state_dict,
|
|
581
519
|
device=init_device,
|
|
582
|
-
dtype=
|
|
583
|
-
in_channel=control_type.get_in_channel(),
|
|
584
|
-
|
|
520
|
+
dtype=config.model_dtype,
|
|
521
|
+
in_channel=config.control_type.get_in_channel(),
|
|
522
|
+
attn_kwargs=attn_kwargs,
|
|
585
523
|
)
|
|
586
|
-
if
|
|
524
|
+
if config.use_fp8_linear:
|
|
587
525
|
enable_fp8_linear(dit)
|
|
588
526
|
|
|
589
527
|
pipe = cls(
|
|
590
|
-
config=
|
|
591
|
-
tokenizer=tokenizer if load_text_encoder else None,
|
|
592
|
-
tokenizer_2=tokenizer_2 if load_text_encoder else None,
|
|
593
|
-
text_encoder_1=text_encoder_1 if load_text_encoder else None,
|
|
594
|
-
text_encoder_2=text_encoder_2 if load_text_encoder else None,
|
|
528
|
+
config=config,
|
|
529
|
+
tokenizer=tokenizer if config.load_text_encoder else None,
|
|
530
|
+
tokenizer_2=tokenizer_2 if config.load_text_encoder else None,
|
|
531
|
+
text_encoder_1=text_encoder_1 if config.load_text_encoder else None,
|
|
532
|
+
text_encoder_2=text_encoder_2 if config.load_text_encoder else None,
|
|
595
533
|
dit=dit,
|
|
596
534
|
vae_decoder=vae_decoder,
|
|
597
535
|
vae_encoder=vae_encoder,
|
|
598
|
-
load_text_encoder=load_text_encoder,
|
|
599
|
-
batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
|
|
600
|
-
vae_tiled=vae_tiled,
|
|
601
|
-
vae_tile_size=vae_tile_size,
|
|
602
|
-
vae_tile_stride=vae_tile_stride,
|
|
603
|
-
control_type=control_type,
|
|
604
|
-
device=device,
|
|
605
|
-
dtype=model_config.dit_dtype,
|
|
606
536
|
)
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
if
|
|
610
|
-
pipe.
|
|
537
|
+
pipe.eval()
|
|
538
|
+
|
|
539
|
+
if config.offload_mode is not None:
|
|
540
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
541
|
+
|
|
542
|
+
if config.model_dtype == torch.float8_e4m3fn:
|
|
543
|
+
pipe.dtype = torch.bfloat16 # compute dtype
|
|
611
544
|
pipe.enable_fp8_autocast(
|
|
612
|
-
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=
|
|
545
|
+
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
|
|
613
546
|
)
|
|
614
547
|
|
|
615
|
-
if
|
|
616
|
-
pipe.dtype = torch.bfloat16 #
|
|
548
|
+
if config.t5_dtype == torch.float8_e4m3fn:
|
|
549
|
+
pipe.dtype = torch.bfloat16 # compute dtype
|
|
617
550
|
pipe.enable_fp8_autocast(
|
|
618
|
-
model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=
|
|
551
|
+
model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
|
|
619
552
|
)
|
|
620
553
|
|
|
621
|
-
if parallelism > 1:
|
|
622
|
-
parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
|
|
623
|
-
cfg_degree = parallel_config["cfg_degree"]
|
|
624
|
-
sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
|
|
625
|
-
sp_ring_degree = parallel_config["sp_ring_degree"]
|
|
626
|
-
tp_degree = parallel_config["tp_degree"]
|
|
627
|
-
use_fsdp = parallel_config["use_fsdp"]
|
|
554
|
+
if config.parallelism > 1:
|
|
628
555
|
return ParallelWrapper(
|
|
629
556
|
pipe,
|
|
630
|
-
cfg_degree=cfg_degree,
|
|
631
|
-
sp_ulysses_degree=sp_ulysses_degree,
|
|
632
|
-
sp_ring_degree=sp_ring_degree,
|
|
633
|
-
tp_degree=tp_degree,
|
|
634
|
-
use_fsdp=use_fsdp,
|
|
557
|
+
cfg_degree=config.cfg_degree,
|
|
558
|
+
sp_ulysses_degree=config.sp_ulysses_degree,
|
|
559
|
+
sp_ring_degree=config.sp_ring_degree,
|
|
560
|
+
tp_degree=config.tp_degree,
|
|
561
|
+
use_fsdp=config.use_fsdp,
|
|
635
562
|
device="cuda",
|
|
636
563
|
)
|
|
637
564
|
return pipe
|
|
638
565
|
|
|
639
566
|
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
|
|
640
|
-
assert self.config.tp_degree is None, (
|
|
567
|
+
assert self.config.tp_degree is None or self.config.tp_degree == 1, (
|
|
641
568
|
"load LoRA is not allowed when tensor parallel is enabled; "
|
|
642
|
-
"set tp_degree=None during pipeline initialization"
|
|
569
|
+
"set tp_degree=None or tp_degree=1 during pipeline initialization"
|
|
643
570
|
)
|
|
644
571
|
assert not (self.config.use_fsdp and fused), (
|
|
645
572
|
"load fused LoRA is not allowed when fully sharded data parallel is enabled; "
|
|
@@ -791,9 +718,9 @@ class FluxImagePipeline(BasePipeline):
|
|
|
791
718
|
total_step: int,
|
|
792
719
|
):
|
|
793
720
|
origin_latents_shape = latents.shape
|
|
794
|
-
if self.control_type != ControlType.normal:
|
|
721
|
+
if self.config.control_type != ControlType.normal:
|
|
795
722
|
controlnet_param = controlnet_params[0]
|
|
796
|
-
if self.control_type == ControlType.bfl_kontext:
|
|
723
|
+
if self.config.control_type == ControlType.bfl_kontext:
|
|
797
724
|
latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
|
|
798
725
|
image_ids = image_ids.repeat(1, 2, 1)
|
|
799
726
|
image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
|
|
@@ -828,7 +755,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
828
755
|
controlnet_double_block_output=double_block_output,
|
|
829
756
|
controlnet_single_block_output=single_block_output,
|
|
830
757
|
)
|
|
831
|
-
if self.control_type == ControlType.bfl_kontext:
|
|
758
|
+
if self.config.control_type == ControlType.bfl_kontext:
|
|
832
759
|
noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
|
|
833
760
|
return noise_pred
|
|
834
761
|
|
|
@@ -877,7 +804,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
877
804
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
878
805
|
latent = self.encode_image(image)
|
|
879
806
|
else:
|
|
880
|
-
if self.control_type == ControlType.normal:
|
|
807
|
+
if self.config.control_type == ControlType.normal:
|
|
881
808
|
image = image.resize((width, height))
|
|
882
809
|
mask = mask.resize((width, height))
|
|
883
810
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
@@ -888,7 +815,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
888
815
|
mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
|
|
889
816
|
mask = 1 - mask
|
|
890
817
|
latent = torch.cat([latent, mask], dim=1)
|
|
891
|
-
elif self.control_type == ControlType.bfl_fill:
|
|
818
|
+
elif self.config.control_type == ControlType.bfl_fill:
|
|
892
819
|
image = image.resize((width, height))
|
|
893
820
|
mask = mask.resize((width, height))
|
|
894
821
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
|
|
@@ -898,7 +825,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
898
825
|
mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
|
899
826
|
latent = torch.cat((image, mask), dim=1)
|
|
900
827
|
else:
|
|
901
|
-
raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.control_type}")
|
|
828
|
+
raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.config.control_type}")
|
|
902
829
|
return latent
|
|
903
830
|
|
|
904
831
|
def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
|
|
@@ -996,7 +923,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
996
923
|
self.dit.refresh_cache_status(num_inference_steps)
|
|
997
924
|
if not isinstance(controlnet_params, list):
|
|
998
925
|
controlnet_params = [controlnet_params]
|
|
999
|
-
if self.control_type != ControlType.normal:
|
|
926
|
+
if self.config.control_type != ControlType.normal:
|
|
1000
927
|
assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
|
|
1001
928
|
|
|
1002
929
|
if input_image is not None:
|
|
@@ -1033,7 +960,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
1033
960
|
if self.ip_adapter is not None and self.redux is not None:
|
|
1034
961
|
raise Exception("ip-adapter and flux redux cannot be used at the same time")
|
|
1035
962
|
elif self.ip_adapter is not None:
|
|
1036
|
-
image_emb = self.ip_adapter(ref_image)
|
|
963
|
+
image_emb = self.ip_adapter.encode_image(ref_image)
|
|
1037
964
|
elif self.redux is not None:
|
|
1038
965
|
image_prompt_embeds = self.redux(ref_image)
|
|
1039
966
|
positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
|
|
@@ -1063,7 +990,7 @@ class FluxImagePipeline(BasePipeline):
|
|
|
1063
990
|
controlnet_params=controlnet_params,
|
|
1064
991
|
current_step=i,
|
|
1065
992
|
total_step=len(timesteps),
|
|
1066
|
-
batch_cfg=self.batch_cfg,
|
|
993
|
+
batch_cfg=self.config.batch_cfg,
|
|
1067
994
|
)
|
|
1068
995
|
# Denoise
|
|
1069
996
|
latents = self.sampler.step(latents, noise_pred, i)
|
|
@@ -1,13 +1,12 @@
|
|
|
1
1
|
import re
|
|
2
|
-
import os
|
|
3
2
|
import torch
|
|
4
3
|
import numpy as np
|
|
5
4
|
from einops import repeat
|
|
6
|
-
from dataclasses import dataclass
|
|
7
5
|
from typing import Callable, Dict, Optional, List
|
|
8
6
|
from tqdm import tqdm
|
|
9
7
|
from PIL import Image, ImageOps
|
|
10
8
|
|
|
9
|
+
from diffsynth_engine.configs import SDPipelineConfig
|
|
11
10
|
from diffsynth_engine.models.base import split_suffix
|
|
12
11
|
from diffsynth_engine.models.basic.lora import LoRAContext
|
|
13
12
|
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
|
|
@@ -84,17 +83,6 @@ def convert_diffusers_name_to_compvis(key):
|
|
|
84
83
|
return key
|
|
85
84
|
|
|
86
85
|
|
|
87
|
-
@dataclass
|
|
88
|
-
class SDModelConfig:
|
|
89
|
-
unet_path: str | os.PathLike
|
|
90
|
-
clip_path: Optional[str | os.PathLike] = None
|
|
91
|
-
vae_path: Optional[str | os.PathLike] = None
|
|
92
|
-
|
|
93
|
-
unet_dtype: torch.dtype = torch.float16
|
|
94
|
-
clip_dtype: torch.dtype = torch.float16
|
|
95
|
-
vae_dtype: torch.dtype = torch.float32
|
|
96
|
-
|
|
97
|
-
|
|
98
86
|
class SDLoRAConverter(LoRAStateDictConverter):
|
|
99
87
|
def _replace_kohya_te_key(self, key):
|
|
100
88
|
key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
|
|
@@ -151,27 +139,22 @@ class SDImagePipeline(BasePipeline):
|
|
|
151
139
|
|
|
152
140
|
def __init__(
|
|
153
141
|
self,
|
|
154
|
-
config:
|
|
142
|
+
config: SDPipelineConfig,
|
|
155
143
|
tokenizer: CLIPTokenizer,
|
|
156
144
|
text_encoder: SDTextEncoder,
|
|
157
145
|
unet: SDUNet,
|
|
158
146
|
vae_decoder: SDVAEDecoder,
|
|
159
147
|
vae_encoder: SDVAEEncoder,
|
|
160
|
-
batch_cfg: bool = True,
|
|
161
|
-
vae_tiled: bool = False,
|
|
162
|
-
vae_tile_size: int = 256,
|
|
163
|
-
vae_tile_stride: int = 256,
|
|
164
|
-
device: str = "cuda",
|
|
165
|
-
dtype: torch.dtype = torch.float16,
|
|
166
148
|
):
|
|
167
149
|
super().__init__(
|
|
168
|
-
vae_tiled=vae_tiled,
|
|
169
|
-
vae_tile_size=vae_tile_size,
|
|
170
|
-
vae_tile_stride=vae_tile_stride,
|
|
171
|
-
device=device,
|
|
172
|
-
dtype=
|
|
150
|
+
vae_tiled=config.vae_tiled,
|
|
151
|
+
vae_tile_size=config.vae_tile_size,
|
|
152
|
+
vae_tile_stride=config.vae_tile_stride,
|
|
153
|
+
device=config.device,
|
|
154
|
+
dtype=config.model_dtype,
|
|
173
155
|
)
|
|
174
156
|
self.config = config
|
|
157
|
+
# sampler
|
|
175
158
|
self.noise_scheduler = ScaledLinearScheduler()
|
|
176
159
|
self.sampler = EulerSampler()
|
|
177
160
|
# models
|
|
@@ -180,71 +163,54 @@ class SDImagePipeline(BasePipeline):
|
|
|
180
163
|
self.unet = unet
|
|
181
164
|
self.vae_decoder = vae_decoder
|
|
182
165
|
self.vae_encoder = vae_encoder
|
|
183
|
-
self.batch_cfg = batch_cfg
|
|
184
166
|
self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
|
185
167
|
|
|
186
168
|
@classmethod
|
|
187
|
-
def from_pretrained(
|
|
188
|
-
cls,
|
|
189
|
-
model_path_or_config: str | os.PathLike | SDModelConfig,
|
|
190
|
-
batch_cfg: bool = True,
|
|
191
|
-
vae_tiled: bool = False,
|
|
192
|
-
vae_tile_size: int = 256,
|
|
193
|
-
vae_tile_stride: int = 256,
|
|
194
|
-
device: str = "cuda",
|
|
195
|
-
dtype: torch.dtype = torch.float16,
|
|
196
|
-
offload_mode: str | None = None,
|
|
197
|
-
) -> "SDImagePipeline":
|
|
169
|
+
def from_pretrained(cls, model_path_or_config: SDPipelineConfig) -> "SDImagePipeline":
|
|
198
170
|
if isinstance(model_path_or_config, str):
|
|
199
|
-
|
|
171
|
+
config = SDPipelineConfig(model_path=model_path_or_config)
|
|
200
172
|
else:
|
|
201
|
-
|
|
173
|
+
config = model_path_or_config
|
|
202
174
|
|
|
203
|
-
logger.info(f"loading state dict from {
|
|
204
|
-
unet_state_dict = cls.load_model_checkpoint(
|
|
175
|
+
logger.info(f"loading state dict from {config.model_path} ...")
|
|
176
|
+
unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
|
|
205
177
|
|
|
206
|
-
if
|
|
207
|
-
logger.info(f"loading state dict from {
|
|
208
|
-
vae_state_dict = cls.load_model_checkpoint(
|
|
178
|
+
if config.vae_path is not None:
|
|
179
|
+
logger.info(f"loading state dict from {config.vae_path} ...")
|
|
180
|
+
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
|
|
209
181
|
else:
|
|
210
182
|
vae_state_dict = unet_state_dict
|
|
211
183
|
|
|
212
|
-
if
|
|
213
|
-
logger.info(f"loading state dict from {
|
|
214
|
-
clip_state_dict = cls.load_model_checkpoint(
|
|
184
|
+
if config.clip_path is not None:
|
|
185
|
+
logger.info(f"loading state dict from {config.clip_path} ...")
|
|
186
|
+
clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
|
|
215
187
|
else:
|
|
216
188
|
clip_state_dict = unet_state_dict
|
|
217
189
|
|
|
218
|
-
init_device = "cpu" if offload_mode else device
|
|
190
|
+
init_device = "cpu" if config.offload_mode is not None else config.device
|
|
219
191
|
tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
|
|
220
192
|
with LoRAContext():
|
|
221
|
-
text_encoder = SDTextEncoder.from_state_dict(
|
|
222
|
-
|
|
223
|
-
)
|
|
224
|
-
unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
|
|
193
|
+
text_encoder = SDTextEncoder.from_state_dict(clip_state_dict, device=init_device, dtype=config.clip_dtype)
|
|
194
|
+
unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
|
|
225
195
|
vae_decoder = SDVAEDecoder.from_state_dict(
|
|
226
|
-
vae_state_dict, device=init_device, dtype=
|
|
196
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
227
197
|
)
|
|
228
198
|
vae_encoder = SDVAEEncoder.from_state_dict(
|
|
229
|
-
vae_state_dict, device=init_device, dtype=
|
|
199
|
+
vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
|
|
230
200
|
)
|
|
231
201
|
|
|
232
202
|
pipe = cls(
|
|
233
|
-
config=
|
|
203
|
+
config=config,
|
|
234
204
|
tokenizer=tokenizer,
|
|
235
205
|
text_encoder=text_encoder,
|
|
236
206
|
unet=unet,
|
|
237
207
|
vae_decoder=vae_decoder,
|
|
238
208
|
vae_encoder=vae_encoder,
|
|
239
|
-
batch_cfg=batch_cfg,
|
|
240
|
-
vae_tiled=vae_tiled,
|
|
241
|
-
vae_tile_size=vae_tile_size,
|
|
242
|
-
vae_tile_stride=vae_tile_stride,
|
|
243
|
-
device=device,
|
|
244
|
-
dtype=dtype,
|
|
245
209
|
)
|
|
246
|
-
|
|
247
|
-
|
|
210
|
+
pipe.eval()
|
|
211
|
+
|
|
212
|
+
if config.offload_mode is not None:
|
|
213
|
+
pipe.enable_cpu_offload(config.offload_mode)
|
|
248
214
|
return pipe
|
|
249
215
|
|
|
250
216
|
@classmethod
|
|
@@ -439,7 +405,7 @@ class SDImagePipeline(BasePipeline):
|
|
|
439
405
|
controlnet_params=controlnet_params,
|
|
440
406
|
current_step=i,
|
|
441
407
|
total_step=len(timesteps),
|
|
442
|
-
batch_cfg=self.batch_cfg,
|
|
408
|
+
batch_cfg=self.config.batch_cfg,
|
|
443
409
|
)
|
|
444
410
|
# Denoise
|
|
445
411
|
latents = self.sampler.step(latents, noise_pred, i)
|