diffsynth_engine-0.3.6.dev9-py3-none-any.whl → diffsynth_engine-0.3.6.dev11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. diffsynth_engine/__init__.py +10 -8
  2. diffsynth_engine/configs/__init__.py +23 -0
  3. diffsynth_engine/configs/controlnet.py +17 -0
  4. diffsynth_engine/configs/pipeline.py +206 -0
  5. diffsynth_engine/models/basic/attention.py +43 -4
  6. diffsynth_engine/models/flux/flux_controlnet.py +8 -5
  7. diffsynth_engine/models/flux/flux_dit.py +22 -16
  8. diffsynth_engine/models/flux/flux_dit_fbcache.py +7 -7
  9. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  10. diffsynth_engine/models/sd/sd_controlnet.py +2 -4
  11. diffsynth_engine/models/sdxl/sdxl_controlnet.py +1 -2
  12. diffsynth_engine/models/wan/wan_dit.py +15 -15
  13. diffsynth_engine/pipelines/__init__.py +5 -8
  14. diffsynth_engine/pipelines/base.py +14 -65
  15. diffsynth_engine/pipelines/flux_image.py +85 -158
  16. diffsynth_engine/pipelines/sd_image.py +30 -64
  17. diffsynth_engine/pipelines/sdxl_image.py +39 -71
  18. diffsynth_engine/pipelines/wan_video.py +66 -105
  19. diffsynth_engine/tools/flux_inpainting_tool.py +7 -3
  20. diffsynth_engine/tools/flux_outpainting_tool.py +7 -3
  21. diffsynth_engine/tools/flux_reference_tool.py +21 -5
  22. diffsynth_engine/tools/flux_replace_tool.py +15 -3
  23. diffsynth_engine/utils/parallel.py +1 -1
  24. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/METADATA +1 -1
  25. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/RECORD +28 -25
  26. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/WHEEL +0 -0
  27. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/licenses/LICENSE +0 -0
  28. {diffsynth_engine-0.3.6.dev9.dist-info → diffsynth_engine-0.3.6.dev11.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,12 @@
1
1
  import re
2
- import os
3
2
  import json
4
3
  import torch
5
4
  import torch.distributed as dist
6
5
  import math
7
6
  from einops import rearrange
8
- from enum import Enum
9
7
  from typing import Callable, Dict, List, Tuple, Optional, Union
10
8
  from tqdm import tqdm
11
9
  from PIL import Image
12
- from dataclasses import dataclass
13
10
  from diffsynth_engine.models.flux import (
14
11
  FluxTextEncoder1,
15
12
  FluxTextEncoder2,
@@ -20,6 +17,7 @@ from diffsynth_engine.models.flux import (
20
17
  flux_dit_config,
21
18
  flux_text_encoder_config,
22
19
  )
20
+ from diffsynth_engine.configs import FluxPipelineConfig, ControlType
23
21
  from diffsynth_engine.models.basic.lora import LoRAContext
24
22
  from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
25
23
  from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
@@ -416,48 +414,12 @@ def calculate_shift(
416
414
  return mu
417
415
 
418
416
 
419
- class ControlType(Enum):
420
- normal = "normal"
421
- bfl_control = "bfl_control"
422
- bfl_fill = "bfl_fill"
423
- bfl_kontext = "bfl_kontext"
424
-
425
- def get_in_channel(self):
426
- if self in [ControlType.normal, ControlType.bfl_kontext]:
427
- return 64
428
- elif self == ControlType.bfl_control:
429
- return 128
430
- elif self == ControlType.bfl_fill:
431
- return 384
432
-
433
-
434
- @dataclass
435
- class FluxModelConfig:
436
- dit_path: str | os.PathLike
437
- clip_path: Optional[str | os.PathLike] = None
438
- t5_path: Optional[str | os.PathLike] = None
439
- vae_path: Optional[str | os.PathLike] = None
440
-
441
- dit_dtype: torch.dtype = torch.bfloat16
442
- clip_dtype: torch.dtype = torch.bfloat16
443
- t5_dtype: torch.dtype = torch.bfloat16
444
- vae_dtype: torch.dtype = torch.bfloat16
445
-
446
- dit_attn_impl: Optional[str] = "auto"
447
- use_fp8_linear: bool = False
448
-
449
- sp_ulysses_degree: Optional[int] = None
450
- sp_ring_degree: Optional[int] = None
451
- tp_degree: Optional[int] = None
452
- use_fsdp: bool = False
453
-
454
-
455
417
  class FluxImagePipeline(BasePipeline):
456
418
  lora_converter = FluxLoRAConverter()
457
419
 
458
420
  def __init__(
459
421
  self,
460
- config: FluxModelConfig,
422
+ config: FluxPipelineConfig,
461
423
  tokenizer: CLIPTokenizer,
462
424
  tokenizer_2: T5TokenizerFast,
463
425
  text_encoder_1: FluxTextEncoder1,
@@ -465,23 +427,16 @@ class FluxImagePipeline(BasePipeline):
465
427
  dit: Union[FluxDiT, FluxDiTFBCache],
466
428
  vae_decoder: FluxVAEDecoder,
467
429
  vae_encoder: FluxVAEEncoder,
468
- load_text_encoder: bool = True,
469
- batch_cfg: bool = False,
470
- vae_tiled: bool = False,
471
- vae_tile_size: int = 256,
472
- vae_tile_stride: int = 256,
473
- control_type: ControlType = ControlType.normal,
474
- device: str = "cuda",
475
- dtype: torch.dtype = torch.bfloat16,
476
430
  ):
477
431
  super().__init__(
478
- vae_tiled=vae_tiled,
479
- vae_tile_size=vae_tile_size,
480
- vae_tile_stride=vae_tile_stride,
481
- device=device,
482
- dtype=dtype,
432
+ vae_tiled=config.vae_tiled,
433
+ vae_tile_size=config.vae_tile_size,
434
+ vae_tile_stride=config.vae_tile_stride,
435
+ device=config.device,
436
+ dtype=config.model_dtype,
483
437
  )
484
438
  self.config = config
439
+ # sampler
485
440
  self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
486
441
  self.sampler = FlowMatchEulerSampler()
487
442
  # models
@@ -492,11 +447,8 @@ class FluxImagePipeline(BasePipeline):
492
447
  self.dit = dit
493
448
  self.vae_decoder = vae_decoder
494
449
  self.vae_encoder = vae_encoder
495
- self.load_text_encoder = load_text_encoder
496
- self.batch_cfg = batch_cfg
497
450
  self.ip_adapter = None
498
451
  self.redux = None
499
- self.control_type = control_type
500
452
  self.model_names = [
501
453
  "text_encoder_1",
502
454
  "text_encoder_2",
@@ -506,140 +458,115 @@ class FluxImagePipeline(BasePipeline):
506
458
  ]
507
459
 
508
460
  @classmethod
509
- def from_pretrained(
510
- cls,
511
- model_path_or_config: str | os.PathLike | FluxModelConfig,
512
- load_text_encoder: bool = True,
513
- batch_cfg: bool = False,
514
- vae_tiled: bool = False,
515
- vae_tile_size: int = 256,
516
- vae_tile_stride: int = 256,
517
- control_type: ControlType = ControlType.normal,
518
- device: str = "cuda",
519
- dtype: torch.dtype = torch.bfloat16,
520
- offload_mode: str | None = None,
521
- parallelism: int = 1,
522
- use_cfg_parallel: bool = False,
523
- use_fb_cache: bool = False,
524
- fb_cache_relative_l1_threshold: float = 0.05,
525
- ) -> "FluxImagePipeline":
526
- model_config = (
527
- model_path_or_config
528
- if isinstance(model_path_or_config, FluxModelConfig)
529
- else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
530
- )
531
- if model_config.vae_path is None:
532
- model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
533
-
534
- if model_config.clip_path is None and load_text_encoder:
535
- model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
536
- if model_config.t5_path is None and load_text_encoder:
537
- model_config.t5_path = fetch_model(
461
+ def from_pretrained(cls, model_path_or_config: str | FluxPipelineConfig) -> "FluxImagePipeline":
462
+ if isinstance(model_path_or_config, str):
463
+ config = FluxPipelineConfig(model_path=model_path_or_config)
464
+ else:
465
+ config = model_path_or_config
466
+
467
+ if config.vae_path is None:
468
+ config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")
469
+ if config.clip_path is None and config.load_text_encoder:
470
+ config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
471
+ if config.t5_path is None and config.load_text_encoder:
472
+ config.t5_path = fetch_model(
538
473
  "muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
539
474
  )
540
475
 
541
- logger.info(f"loading state dict from {model_config.dit_path} ...")
542
- dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=model_config.dit_dtype)
543
- logger.info(f"loading state dict from {model_config.vae_path} ...")
544
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)
545
- if load_text_encoder:
546
- logger.info(f"loading state dict from {model_config.clip_path} ...")
547
- clip_state_dict = cls.load_model_checkpoint(
548
- model_config.clip_path, device="cpu", dtype=model_config.clip_dtype
549
- )
550
- logger.info(f"loading state dict from {model_config.t5_path} ...")
551
- t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)
552
-
553
- init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
554
- if load_text_encoder:
476
+ logger.info(f"loading state dict from {config.model_path} ...")
477
+ dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
478
+ logger.info(f"loading state dict from {config.vae_path} ...")
479
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
480
+ if config.load_text_encoder:
481
+ logger.info(f"loading state dict from {config.clip_path} ...")
482
+ clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
483
+ logger.info(f"loading state dict from {config.t5_path} ...")
484
+ t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype)
485
+
486
+ init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
487
+ if config.load_text_encoder:
555
488
  tokenizer = CLIPTokenizer.from_pretrained(FLUX_TOKENIZER_1_CONF_PATH)
556
489
  tokenizer_2 = T5TokenizerFast.from_pretrained(FLUX_TOKENIZER_2_CONF_PATH)
557
490
  with LoRAContext():
558
491
  text_encoder_1 = FluxTextEncoder1.from_state_dict(
559
- clip_state_dict, device=init_device, dtype=model_config.clip_dtype
492
+ clip_state_dict, device=init_device, dtype=config.clip_dtype
560
493
  )
561
- text_encoder_2 = FluxTextEncoder2.from_state_dict(
562
- t5_state_dict, device=init_device, dtype=model_config.t5_dtype
563
- )
494
+ text_encoder_2 = FluxTextEncoder2.from_state_dict(t5_state_dict, device=init_device, dtype=config.t5_dtype)
564
495
 
565
- vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
566
- vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
496
+ vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
497
+ vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=config.vae_dtype)
567
498
 
568
499
  with LoRAContext():
569
- if use_fb_cache:
500
+ attn_kwargs = {
501
+ "attn_impl": config.dit_attn_impl,
502
+ "sparge_smooth_k": config.sparge_smooth_k,
503
+ "sparge_cdfthreshd": config.sparge_cdfthreshd,
504
+ "sparge_simthreshd1": config.sparge_simthreshd1,
505
+ "sparge_pvthreshd": config.sparge_pvthreshd,
506
+ }
507
+ if config.use_fbcache:
570
508
  dit = FluxDiTFBCache.from_state_dict(
571
509
  dit_state_dict,
572
510
  device=init_device,
573
- dtype=model_config.dit_dtype,
574
- in_channel=control_type.get_in_channel(),
575
- attn_impl=model_config.dit_attn_impl,
576
- relative_l1_threshold=fb_cache_relative_l1_threshold,
511
+ dtype=config.model_dtype,
512
+ in_channel=config.control_type.get_in_channel(),
513
+ attn_kwargs=attn_kwargs,
514
+ relative_l1_threshold=config.fbcache_relative_l1_threshold,
577
515
  )
578
516
  else:
579
517
  dit = FluxDiT.from_state_dict(
580
518
  dit_state_dict,
581
519
  device=init_device,
582
- dtype=model_config.dit_dtype,
583
- in_channel=control_type.get_in_channel(),
584
- attn_impl=model_config.dit_attn_impl,
520
+ dtype=config.model_dtype,
521
+ in_channel=config.control_type.get_in_channel(),
522
+ attn_kwargs=attn_kwargs,
585
523
  )
586
- if model_config.use_fp8_linear:
524
+ if config.use_fp8_linear:
587
525
  enable_fp8_linear(dit)
588
526
 
589
527
  pipe = cls(
590
- config=model_config,
591
- tokenizer=tokenizer if load_text_encoder else None,
592
- tokenizer_2=tokenizer_2 if load_text_encoder else None,
593
- text_encoder_1=text_encoder_1 if load_text_encoder else None,
594
- text_encoder_2=text_encoder_2 if load_text_encoder else None,
528
+ config=config,
529
+ tokenizer=tokenizer if config.load_text_encoder else None,
530
+ tokenizer_2=tokenizer_2 if config.load_text_encoder else None,
531
+ text_encoder_1=text_encoder_1 if config.load_text_encoder else None,
532
+ text_encoder_2=text_encoder_2 if config.load_text_encoder else None,
595
533
  dit=dit,
596
534
  vae_decoder=vae_decoder,
597
535
  vae_encoder=vae_encoder,
598
- load_text_encoder=load_text_encoder,
599
- batch_cfg=True if parallelism > 1 and use_cfg_parallel else batch_cfg,
600
- vae_tiled=vae_tiled,
601
- vae_tile_size=vae_tile_size,
602
- vae_tile_stride=vae_tile_stride,
603
- control_type=control_type,
604
- device=device,
605
- dtype=model_config.dit_dtype,
606
536
  )
607
- if offload_mode is not None:
608
- pipe.enable_cpu_offload(offload_mode)
609
- if model_config.dit_dtype == torch.float8_e4m3fn:
610
- pipe.dtype = torch.bfloat16 # running dtype
537
+ pipe.eval()
538
+
539
+ if config.offload_mode is not None:
540
+ pipe.enable_cpu_offload(config.offload_mode)
541
+
542
+ if config.model_dtype == torch.float8_e4m3fn:
543
+ pipe.dtype = torch.bfloat16 # compute dtype
611
544
  pipe.enable_fp8_autocast(
612
- model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
545
+ model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
613
546
  )
614
547
 
615
- if model_config.t5_dtype == torch.float8_e4m3fn:
616
- pipe.dtype = torch.bfloat16 # running dtype
548
+ if config.t5_dtype == torch.float8_e4m3fn:
549
+ pipe.dtype = torch.bfloat16 # compute dtype
617
550
  pipe.enable_fp8_autocast(
618
- model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
551
+ model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
619
552
  )
620
553
 
621
- if parallelism > 1:
622
- parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
623
- cfg_degree = parallel_config["cfg_degree"]
624
- sp_ulysses_degree = parallel_config["sp_ulysses_degree"]
625
- sp_ring_degree = parallel_config["sp_ring_degree"]
626
- tp_degree = parallel_config["tp_degree"]
627
- use_fsdp = parallel_config["use_fsdp"]
554
+ if config.parallelism > 1:
628
555
  return ParallelWrapper(
629
556
  pipe,
630
- cfg_degree=cfg_degree,
631
- sp_ulysses_degree=sp_ulysses_degree,
632
- sp_ring_degree=sp_ring_degree,
633
- tp_degree=tp_degree,
634
- use_fsdp=use_fsdp,
557
+ cfg_degree=config.cfg_degree,
558
+ sp_ulysses_degree=config.sp_ulysses_degree,
559
+ sp_ring_degree=config.sp_ring_degree,
560
+ tp_degree=config.tp_degree,
561
+ use_fsdp=config.use_fsdp,
635
562
  device="cuda",
636
563
  )
637
564
  return pipe
638
565
 
639
566
  def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
640
- assert self.config.tp_degree is None, (
567
+ assert self.config.tp_degree is None or self.config.tp_degree == 1, (
641
568
  "load LoRA is not allowed when tensor parallel is enabled; "
642
- "set tp_degree=None during pipeline initialization"
569
+ "set tp_degree=None or tp_degree=1 during pipeline initialization"
643
570
  )
644
571
  assert not (self.config.use_fsdp and fused), (
645
572
  "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
@@ -791,9 +718,9 @@ class FluxImagePipeline(BasePipeline):
791
718
  total_step: int,
792
719
  ):
793
720
  origin_latents_shape = latents.shape
794
- if self.control_type != ControlType.normal:
721
+ if self.config.control_type != ControlType.normal:
795
722
  controlnet_param = controlnet_params[0]
796
- if self.control_type == ControlType.bfl_kontext:
723
+ if self.config.control_type == ControlType.bfl_kontext:
797
724
  latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
798
725
  image_ids = image_ids.repeat(1, 2, 1)
799
726
  image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
@@ -828,7 +755,7 @@ class FluxImagePipeline(BasePipeline):
828
755
  controlnet_double_block_output=double_block_output,
829
756
  controlnet_single_block_output=single_block_output,
830
757
  )
831
- if self.control_type == ControlType.bfl_kontext:
758
+ if self.config.control_type == ControlType.bfl_kontext:
832
759
  noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
833
760
  return noise_pred
834
761
 
@@ -877,7 +804,7 @@ class FluxImagePipeline(BasePipeline):
877
804
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
878
805
  latent = self.encode_image(image)
879
806
  else:
880
- if self.control_type == ControlType.normal:
807
+ if self.config.control_type == ControlType.normal:
881
808
  image = image.resize((width, height))
882
809
  mask = mask.resize((width, height))
883
810
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
@@ -888,7 +815,7 @@ class FluxImagePipeline(BasePipeline):
888
815
  mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
889
816
  mask = 1 - mask
890
817
  latent = torch.cat([latent, mask], dim=1)
891
- elif self.control_type == ControlType.bfl_fill:
818
+ elif self.config.control_type == ControlType.bfl_fill:
892
819
  image = image.resize((width, height))
893
820
  mask = mask.resize((width, height))
894
821
  image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
@@ -898,7 +825,7 @@ class FluxImagePipeline(BasePipeline):
898
825
  mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
899
826
  latent = torch.cat((image, mask), dim=1)
900
827
  else:
901
- raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.control_type}")
828
+ raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.config.control_type}")
902
829
  return latent
903
830
 
904
831
  def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
@@ -996,7 +923,7 @@ class FluxImagePipeline(BasePipeline):
996
923
  self.dit.refresh_cache_status(num_inference_steps)
997
924
  if not isinstance(controlnet_params, list):
998
925
  controlnet_params = [controlnet_params]
999
- if self.control_type != ControlType.normal:
926
+ if self.config.control_type != ControlType.normal:
1000
927
  assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
1001
928
 
1002
929
  if input_image is not None:
@@ -1033,7 +960,7 @@ class FluxImagePipeline(BasePipeline):
1033
960
  if self.ip_adapter is not None and self.redux is not None:
1034
961
  raise Exception("ip-adapter and flux redux cannot be used at the same time")
1035
962
  elif self.ip_adapter is not None:
1036
- image_emb = self.ip_adapter(ref_image)
963
+ image_emb = self.ip_adapter.encode_image(ref_image)
1037
964
  elif self.redux is not None:
1038
965
  image_prompt_embeds = self.redux(ref_image)
1039
966
  positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
@@ -1063,7 +990,7 @@ class FluxImagePipeline(BasePipeline):
1063
990
  controlnet_params=controlnet_params,
1064
991
  current_step=i,
1065
992
  total_step=len(timesteps),
1066
- batch_cfg=self.batch_cfg,
993
+ batch_cfg=self.config.batch_cfg,
1067
994
  )
1068
995
  # Denoise
1069
996
  latents = self.sampler.step(latents, noise_pred, i)
@@ -1,13 +1,12 @@
1
1
  import re
2
- import os
3
2
  import torch
4
3
  import numpy as np
5
4
  from einops import repeat
6
- from dataclasses import dataclass
7
5
  from typing import Callable, Dict, Optional, List
8
6
  from tqdm import tqdm
9
7
  from PIL import Image, ImageOps
10
8
 
9
+ from diffsynth_engine.configs import SDPipelineConfig
11
10
  from diffsynth_engine.models.base import split_suffix
12
11
  from diffsynth_engine.models.basic.lora import LoRAContext
13
12
  from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
@@ -84,17 +83,6 @@ def convert_diffusers_name_to_compvis(key):
84
83
  return key
85
84
 
86
85
 
87
- @dataclass
88
- class SDModelConfig:
89
- unet_path: str | os.PathLike
90
- clip_path: Optional[str | os.PathLike] = None
91
- vae_path: Optional[str | os.PathLike] = None
92
-
93
- unet_dtype: torch.dtype = torch.float16
94
- clip_dtype: torch.dtype = torch.float16
95
- vae_dtype: torch.dtype = torch.float32
96
-
97
-
98
86
  class SDLoRAConverter(LoRAStateDictConverter):
99
87
  def _replace_kohya_te_key(self, key):
100
88
  key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
@@ -151,27 +139,22 @@ class SDImagePipeline(BasePipeline):
151
139
 
152
140
  def __init__(
153
141
  self,
154
- config: SDModelConfig,
142
+ config: SDPipelineConfig,
155
143
  tokenizer: CLIPTokenizer,
156
144
  text_encoder: SDTextEncoder,
157
145
  unet: SDUNet,
158
146
  vae_decoder: SDVAEDecoder,
159
147
  vae_encoder: SDVAEEncoder,
160
- batch_cfg: bool = True,
161
- vae_tiled: bool = False,
162
- vae_tile_size: int = 256,
163
- vae_tile_stride: int = 256,
164
- device: str = "cuda",
165
- dtype: torch.dtype = torch.float16,
166
148
  ):
167
149
  super().__init__(
168
- vae_tiled=vae_tiled,
169
- vae_tile_size=vae_tile_size,
170
- vae_tile_stride=vae_tile_stride,
171
- device=device,
172
- dtype=dtype,
150
+ vae_tiled=config.vae_tiled,
151
+ vae_tile_size=config.vae_tile_size,
152
+ vae_tile_stride=config.vae_tile_stride,
153
+ device=config.device,
154
+ dtype=config.model_dtype,
173
155
  )
174
156
  self.config = config
157
+ # sampler
175
158
  self.noise_scheduler = ScaledLinearScheduler()
176
159
  self.sampler = EulerSampler()
177
160
  # models
@@ -180,71 +163,54 @@ class SDImagePipeline(BasePipeline):
180
163
  self.unet = unet
181
164
  self.vae_decoder = vae_decoder
182
165
  self.vae_encoder = vae_encoder
183
- self.batch_cfg = batch_cfg
184
166
  self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
185
167
 
186
168
  @classmethod
187
- def from_pretrained(
188
- cls,
189
- model_path_or_config: str | os.PathLike | SDModelConfig,
190
- batch_cfg: bool = True,
191
- vae_tiled: bool = False,
192
- vae_tile_size: int = 256,
193
- vae_tile_stride: int = 256,
194
- device: str = "cuda",
195
- dtype: torch.dtype = torch.float16,
196
- offload_mode: str | None = None,
197
- ) -> "SDImagePipeline":
169
+ def from_pretrained(cls, model_path_or_config: SDPipelineConfig) -> "SDImagePipeline":
198
170
  if isinstance(model_path_or_config, str):
199
- model_config = SDModelConfig(unet_path=model_path_or_config)
171
+ config = SDPipelineConfig(model_path=model_path_or_config)
200
172
  else:
201
- model_config = model_path_or_config
173
+ config = model_path_or_config
202
174
 
203
- logger.info(f"loading state dict from {model_config.unet_path} ...")
204
- unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)
175
+ logger.info(f"loading state dict from {config.model_path} ...")
176
+ unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
205
177
 
206
- if model_config.vae_path is not None:
207
- logger.info(f"loading state dict from {model_config.vae_path} ...")
208
- vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
178
+ if config.vae_path is not None:
179
+ logger.info(f"loading state dict from {config.vae_path} ...")
180
+ vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
209
181
  else:
210
182
  vae_state_dict = unet_state_dict
211
183
 
212
- if model_config.clip_path is not None:
213
- logger.info(f"loading state dict from {model_config.clip_path} ...")
214
- clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
184
+ if config.clip_path is not None:
185
+ logger.info(f"loading state dict from {config.clip_path} ...")
186
+ clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
215
187
  else:
216
188
  clip_state_dict = unet_state_dict
217
189
 
218
- init_device = "cpu" if offload_mode else device
190
+ init_device = "cpu" if config.offload_mode is not None else config.device
219
191
  tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
220
192
  with LoRAContext():
221
- text_encoder = SDTextEncoder.from_state_dict(
222
- clip_state_dict, device=init_device, dtype=model_config.clip_dtype
223
- )
224
- unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
193
+ text_encoder = SDTextEncoder.from_state_dict(clip_state_dict, device=init_device, dtype=config.clip_dtype)
194
+ unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
225
195
  vae_decoder = SDVAEDecoder.from_state_dict(
226
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
196
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
227
197
  )
228
198
  vae_encoder = SDVAEEncoder.from_state_dict(
229
- vae_state_dict, device=init_device, dtype=model_config.vae_dtype, attn_impl="sdpa"
199
+ vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
230
200
  )
231
201
 
232
202
  pipe = cls(
233
- config=model_config,
203
+ config=config,
234
204
  tokenizer=tokenizer,
235
205
  text_encoder=text_encoder,
236
206
  unet=unet,
237
207
  vae_decoder=vae_decoder,
238
208
  vae_encoder=vae_encoder,
239
- batch_cfg=batch_cfg,
240
- vae_tiled=vae_tiled,
241
- vae_tile_size=vae_tile_size,
242
- vae_tile_stride=vae_tile_stride,
243
- device=device,
244
- dtype=dtype,
245
209
  )
246
- if offload_mode is not None:
247
- pipe.enable_cpu_offload(offload_mode)
210
+ pipe.eval()
211
+
212
+ if config.offload_mode is not None:
213
+ pipe.enable_cpu_offload(config.offload_mode)
248
214
  return pipe
249
215
 
250
216
  @classmethod
@@ -439,7 +405,7 @@ class SDImagePipeline(BasePipeline):
439
405
  controlnet_params=controlnet_params,
440
406
  current_step=i,
441
407
  total_step=len(timesteps),
442
- batch_cfg=self.batch_cfg,
408
+ batch_cfg=self.config.batch_cfg,
443
409
  )
444
410
  # Denoise
445
411
  latents = self.sampler.step(latents, noise_pred, i)