diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0

diffsynth_engine/pipelines/flux_image.py

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -34,16 +39,17 @@ from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE

 logger = logging.get_logger(__name__)

-with open(FLUX_DIT_CONFIG_FILE, "r") as f:
+with open(FLUX_DIT_CONFIG_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)

+PREFERRED_KONTEXT_RESOLUTIONS = config["preferred_kontext_resolutions"]
+

 class FluxLoRAConverter(LoRAStateDictConverter):
     def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         flux_dim = 3072
         dit_rename_dict = flux_dit_config["civitai"]["rename_dict"]
         dit_suffix_rename_dict = flux_dit_config["civitai"]["suffix_rename_dict"]
-        clip_rename_dict = flux_text_encoder_config["diffusers"]["rename_dict"]
         clip_attn_rename_dict = flux_text_encoder_config["diffusers"]["attn_rename_dict"]

         dit_dict = {}
@@ -136,27 +142,18 @@ class FluxLoRAConverter(LoRAStateDictConverter):
                 lora_args["rank"] = lora_args["up"].shape[1]
                 rename = rename.replace(".weight", "")
                 dit_dict[rename] = lora_args
-            elif "
-                name = key.replace("
-                name = name.replace("
-
-
-
-
-                param = param.reshape((1, param.shape[0], param.shape[1]))
-                rename = clip_rename_dict[name]
-            elif name.startswith("text_model.encoder.layers."):
-                names = name.split(".")
-                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
-                rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type], tail])
-            else:
-                raise ValueError(f"Unsupported key: {key}")
+            elif "lora_te1_text_model_encoder_layers_" in key:
+                name = key.replace("lora_te1_text_model_encoder_layers_", "")
+                name = name.replace(".alpha", "")
+                layer_id, layer_type = name.split("_", 1)
+                layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
+                rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
+
                 lora_args = {}
                 lora_args["alpha"] = param
                 lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
                 lora_args["down"] = lora_state_dict[origin_key.replace(".alpha", ".lora_down.weight")]
                 lora_args["rank"] = lora_args["up"].shape[1]
-                rename = rename.replace(".weight", "")
                 te_dict[rename] = lora_args
             else:
                 raise ValueError(f"Unsupported key: {key}")
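For context on the rewritten `_from_kohya` branch above: kohya-style text-encoder LoRA keys flatten the CLIP module path with underscores, and the new code splits that back apart. A minimal standalone sketch of the renaming, where `clip_attn_rename_dict` is a hypothetical one-entry stand-in for `flux_text_encoder_config["diffusers"]["attn_rename_dict"]`:

    # Sketch of the kohya -> diffsynth key mapping; the dict entry is assumed.
    clip_attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}

    key = "lora_te1_text_model_encoder_layers_7_self_attn_q_proj.alpha"
    name = key.replace("lora_te1_text_model_encoder_layers_", "").replace(".alpha", "")
    layer_id, layer_type = name.split("_", 1)              # "7", "self_attn_q_proj"
    layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
    rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
    print(rename)  # encoders.7.attn.to_q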
@@ -515,29 +512,20 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = FluxDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
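The `device=("cpu" if config.use_fsdp else init_device)` change follows the usual FSDP initialization pattern: materialize the full weights on CPU so no single rank ever holds the whole model in GPU memory, then let the sharding wrapper move only each rank's local shard onto the device. A minimal sketch of that pattern, assuming standard PyTorch FSDP rather than this library's own wrapping code:

    import torch
    import torch.nn as nn

    # Build on CPU so the full (unsharded) weights never touch GPU memory.
    model = nn.Linear(4096, 4096, device="cpu")

    # In a torch.distributed run the wrapper would then shard and move it:
    # from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    # model = FSDP(model, device_id=torch.cuda.current_device())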
@@ -573,8 +561,15 @@ class FluxImagePipeline(BasePipeline):
             pipe.compile()
         return pipe

+    def update_weights(self, state_dicts: FluxStateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(self.text_encoder_1, state_dicts.clip, self.config.device, self.config.clip_dtype)
+        self.update_component(self.text_encoder_2, state_dicts.t5, self.config.device, self.config.t5_dtype)
+        self.update_component(self.vae_decoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
+        self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
     def compile(self):
-        self.dit.compile_repeated_blocks(
+        self.dit.compile_repeated_blocks()

     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
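The new `update_weights` hook allows swapping checkpoints without rebuilding the pipeline. A hypothetical caller-side sketch, assuming `FluxStateDicts` accepts the `model`/`clip`/`t5`/`vae` fields used above and that the checkpoints are safetensors files (all paths are placeholders):

    from safetensors.torch import load_file
    from diffsynth_engine.configs import FluxStateDicts

    state_dicts = FluxStateDicts(
        model=load_file("flux_dit.safetensors"),  # placeholder paths
        clip=load_file("clip_l.safetensors"),
        t5=load_file("t5xxl.safetensors"),
        vae=load_file("ae.safetensors"),
    )
    pipe.update_weights(state_dicts)  # pipe: an already-constructed FluxImagePipeline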
@@ -612,7 +607,7 @@ class FluxImagePipeline(BasePipeline):
         return prompt_emb, add_text_embeds

     def prepare_extra_input(self, latents, positive_prompt_emb, guidance=1.0):
-        image_ids =
+        image_ids = self.dit.prepare_image_ids(latents)
         guidance = torch.tensor([guidance] * latents.shape[0], device=latents.device, dtype=latents.dtype)
         text_ids = torch.zeros(positive_prompt_emb.shape[0], positive_prompt_emb.shape[1], 3).to(
             device=self.device, dtype=positive_prompt_emb.dtype
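Image-id construction is now delegated to the DiT via `prepare_image_ids`. For intuition: FLUX-style models attach a 3-channel position id to every latent patch token, with channel 0 free for tagging (the Kontext branch later sets it to `idx + 1` for reference images) and channels 1 and 2 holding the row and column. A minimal sketch mirroring the common FLUX layout, not this library's exact method:

    import torch

    def make_image_ids(height: int, width: int) -> torch.Tensor:
        # one (tag, row, col) triple per 2x2 latent patch
        h, w = height // 2, width // 2
        ids = torch.zeros(h, w, 3)
        ids[..., 1] = torch.arange(h)[:, None]  # row index
        ids[..., 2] = torch.arange(w)[None, :]  # column index
        return ids.reshape(1, h * w, 3)         # (batch, seq, 3)

    image_ids = make_image_ids(64, 64)          # 64x64 latent -> 1024 patch tokens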
@@ -639,45 +634,45 @@ class FluxImagePipeline(BasePipeline):
     ):
         if cfg_scale <= 1.0:
             return self.predict_noise(
-                latents,
-                timestep,
-                positive_prompt_emb,
-                positive_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=positive_prompt_emb,
+                add_text_embeds=positive_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
         if not batch_cfg:
             # cfg by predict noise one by one
             positive_noise_pred = self.predict_noise(
-                latents,
-                timestep,
-                positive_prompt_emb,
-                positive_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=positive_prompt_emb,
+                add_text_embeds=positive_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
             negative_noise_pred = self.predict_noise(
-                latents,
-                timestep,
-                negative_prompt_emb,
-                negative_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=negative_prompt_emb,
+                add_text_embeds=negative_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
             noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
             return noise_pred
@@ -692,17 +687,17 @@ class FluxImagePipeline(BasePipeline):
         text_ids = torch.cat([text_ids, text_ids], dim=0)
         guidance = torch.cat([guidance, guidance], dim=0)
         positive_noise_pred, negative_noise_pred = self.predict_noise(
-            latents,
-            timestep,
-            prompt_emb,
-            add_text_embeds,
-            image_emb,
-            image_ids,
-            text_ids,
-            guidance,
-            controlnet_params,
-            current_step,
-            total_step,
+            latents=latents,
+            timestep=timestep,
+            prompt_emb=prompt_emb,
+            add_text_embeds=add_text_embeds,
+            image_emb=image_emb,
+            image_ids=image_ids,
+            text_ids=text_ids,
+            guidance=guidance,
+            controlnet_params=controlnet_params,
+            current_step=current_step,
+            total_step=total_step,
         )
         noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
         return noise_pred
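Both CFG paths above end in the same combination step. A small self-contained sketch of the batched variant, where the positive and negative conditions share one forward pass over a doubled batch and the result is split back apart:

    import torch

    def cfg_combine(pos: torch.Tensor, neg: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # extrapolate from the unconditional prediction toward the conditional one
        return neg + cfg_scale * (pos - neg)

    model_out = torch.randn(2, 1024, 64)  # doubled batch: [positive, negative]
    positive_noise_pred, negative_noise_pred = model_out.chunk(2, dim=0)
    noise_pred = cfg_combine(positive_noise_pred, negative_noise_pred, cfg_scale=3.5)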
@@ -721,32 +716,42 @@ class FluxImagePipeline(BasePipeline):
         current_step: int,
         total_step: int,
     ):
-
-
+        height, width = latents.shape[2:]
+        latents = self.dit.patchify(latents)
+        image_seq_len = latents.shape[1]
+
+        double_block_output, single_block_output = None, None
+        if self.config.control_type == ControlType.normal:
+            double_block_output, single_block_output = self.predict_multicontrolnet(
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=prompt_emb,
+                add_text_embeds=add_text_embeds,
+                guidance=guidance,
+                text_ids=text_ids,
+                image_ids=image_ids,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
+            )
+        elif self.config.control_type == ControlType.bfl_kontext:
+            for idx, controlnet_param in enumerate(controlnet_params):
+                control_latents = controlnet_param.image * controlnet_param.scale
+                control_image_ids = self.dit.prepare_image_ids(control_latents)
+                control_image_ids[..., 0] = idx + 1
+                control_latents = self.dit.patchify(control_latents)
+                latents = torch.cat((latents, control_latents), dim=1)
+                image_ids = torch.cat((image_ids, control_image_ids), dim=1)
+        else:
             controlnet_param = controlnet_params[0]
-
-
-
-            image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
-        else:
-            latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1)
-            latents = latents.to(self.dtype)
-            controlnet_params = []
+            control_latents = controlnet_param.image * controlnet_param.scale
+            control_latents = self.dit.patchify(control_latents)
+            latents = torch.cat((latents, control_latents), dim=2)

-
-            latents=latents,
-            timestep=timestep,
-            prompt_emb=prompt_emb,
-            add_text_embeds=add_text_embeds,
-            guidance=guidance,
-            text_ids=text_ids,
-            image_ids=image_ids,
-            controlnet_params=controlnet_params,
-            current_step=current_step,
-            total_step=total_step,
-        )
+        latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])

+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
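The new `predict_noise` patchifies latents up front, appends Kontext reference tokens along the sequence dimension (tagged via `control_image_ids[..., 0] = idx + 1`), and, as the next hunk shows, slices them back off with `noise_pred[:, :image_seq_len]` before unpatchifying. A minimal sketch of the 2x2 patchify/unpatchify round trip, written with `einops` to illustrate the shape contract rather than the DiT's actual methods:

    import torch
    from einops import rearrange

    latents = torch.randn(1, 16, 64, 64)  # (batch, channels, height, width)

    # patchify: each 2x2 latent patch becomes one sequence token
    seq = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    assert seq.shape == (1, 32 * 32, 16 * 4)

    # unpatchify: invert the mapping given the original spatial size
    restored = rearrange(seq, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=32, w=32, ph=2, pw=2)
    assert torch.equal(latents, restored)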
@@ -758,9 +763,10 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
-
-
+        noise_pred = noise_pred[:, :image_seq_len]
+        noise_pred = self.dit.unpatchify(noise_pred, height, width)
         return noise_pred

     def prepare_latents(
@@ -782,7 +788,7 @@ class FluxImagePipeline(BasePipeline):
         sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
         timesteps = timesteps[t_start - 1 :]
         noise = latents
-        image = self.preprocess_image(input_image).to(device=self.device
+        image = self.preprocess_image(input_image).to(device=self.device)
         latents = self.encode_image(image)
         init_latents = latents.clone()
         latents = self.sampler.add_noise(latents, noise, sigma_start)
@@ -804,26 +810,32 @@ class FluxImagePipeline(BasePipeline):
     def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int):
         self.load_models_to_device(["vae_encoder"])
         if mask is None:
+            if self.config.control_type == ControlType.bfl_kontext:
+                width, height = image.size
+                aspect_ratio = width / height
+                # Kontext is trained on specific resolutions, using one of them is recommended
+                _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
+                width, height = 16 * (width // 16), 16 * (height // 16)
             image = image.resize((width, height))
-            image = self.preprocess_image(image).to(device=self.device
+            image = self.preprocess_image(image).to(device=self.device)
             latent = self.encode_image(image)
         else:
             if self.config.control_type == ControlType.normal:
                 image = image.resize((width, height))
                 mask = mask.resize((width, height))
-                image = self.preprocess_image(image).to(device=self.device
-                mask = self.preprocess_mask(mask).to(device=self.device
+                image = self.preprocess_image(image).to(device=self.device)
+                mask = self.preprocess_mask(mask).to(device=self.device)
                 masked_image = image.clone()
                 masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
                 latent = self.encode_image(masked_image)
-                mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+                mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
                 mask = 1 - mask
                 latent = torch.cat([latent, mask], dim=1)
             elif self.config.control_type == ControlType.bfl_fill:
                 image = image.resize((width, height))
                 mask = mask.resize((width, height))
-                image = self.preprocess_image(image).to(device=self.device
-                mask = self.preprocess_mask(mask).to(device=self.device
+                image = self.preprocess_image(image).to(device=self.device)
+                mask = self.preprocess_mask(mask).to(device=self.device)
                 image = image * (1 - mask)
                 image = self.encode_image(image)
                 mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
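The Kontext branch above snaps arbitrary input sizes to the nearest trained aspect ratio. A standalone sketch of that selection with a hypothetical subset of the resolution list (the real values are loaded from `flux_dit.json`'s `preferred_kontext_resolutions`):

    PREFERRED_KONTEXT_RESOLUTIONS = [  # hypothetical subset for illustration
        (672, 1568), (880, 1184), (1024, 1024), (1184, 880), (1568, 672),
    ]

    def snap_resolution(width: int, height: int) -> tuple[int, int]:
        aspect_ratio = width / height
        # pick the trained resolution whose aspect ratio is closest to the input's
        _, w, h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
        return 16 * (w // 16), 16 * (h // 16)  # keep multiples of 16 for patchification

    print(snap_resolution(1920, 1080))  # -> (1184, 880)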
@@ -862,6 +874,7 @@ class FluxImagePipeline(BasePipeline):
         if len(controlnet_params) > 0:
             self.load_models_to_device([])
             for param in controlnet_params:
+                control_condition = param.model.patchify(param.image)
                 current_scale = param.scale
                 if not (
                     current_step >= param.control_start * total_step and current_step <= param.control_end * total_step
@@ -872,16 +885,19 @@ class FluxImagePipeline(BasePipeline):
                 if self.offload_mode is not None:
                     empty_cache()
                     param.model.to(self.device)
+
+                attn_kwargs = self.get_attn_kwargs(latents)
                 double_block_output, single_block_output = param.model(
-                    latents,
-
-                    current_scale,
-                    timestep,
-                    prompt_emb,
-                    add_text_embeds,
-
-
-
+                    hidden_states=latents,
+                    control_condition=control_condition,
+                    control_scale=current_scale,
+                    timestep=timestep,
+                    prompt_emb=prompt_emb,
+                    pooled_prompt_emb=add_text_embeds,
+                    image_ids=image_ids,
+                    text_ids=text_ids,
+                    guidance=guidance,
+                    attn_kwargs=attn_kwargs,
                 )
                 if self.offload_mode is not None:
                     param.model.to("cpu")
@@ -927,8 +943,10 @@ class FluxImagePipeline(BasePipeline):
         self.dit.refresh_cache_status(num_inference_steps)
         if not isinstance(controlnet_params, list):
             controlnet_params = [controlnet_params]
-        if self.config.control_type
-        assert controlnet_params and len(controlnet_params) == 1,
+        if self.config.control_type in [ControlType.bfl_control, ControlType.bfl_fill]:
+            assert controlnet_params and len(controlnet_params) == 1, (
+                "bfl_controlnet or bfl_fill must have one controlnet"
+            )

         if input_image is not None:
             width, height = input_image.size
@@ -966,8 +984,9 @@ class FluxImagePipeline(BasePipeline):
         elif self.ip_adapter is not None:
             image_emb = self.ip_adapter.encode_image(ref_image)
         elif self.redux is not None:
-
-
+            ref_prompt_embeds = self.redux(ref_image)
+            flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1))
+            positive_prompt_emb = torch.cat([positive_prompt_emb, flattened_ref_emb], dim=1)

         # Extra input
         image_ids, text_ids, guidance = self.prepare_extra_input(
diffsynth_engine/pipelines/hunyuan3d_shape.py

@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, Callable
 from tqdm import tqdm
 from PIL import Image
 from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
@@ -179,6 +180,7 @@ class Hunyuan3DShapePipeline(BasePipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         seed: int = 42,
+        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):
         image_emb = self.encode_image(image)

@@ -197,4 +199,6 @@ class Hunyuan3DShapePipeline(BasePipeline):
             noise_pred, noise_pred_uncond = model_outputs.chunk(2)
             model_outputs = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
             latents = self.sampler.step(latents, model_outputs, i)
+            if progress_callback is not None:
+                progress_callback(i, len(timesteps), "DENOISING")
         return self.decode_latents(latents)
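A hypothetical caller-side sketch of the new `progress_callback` parameter, matching the `(current, total, status)` signature noted in the diff:

    def on_progress(current: int, total: int, status: str) -> None:
        # e.g. forward to a UI progress bar or a job queue
        print(f"[{status}] step {current + 1}/{total}")

    # mesh = pipe(image, num_inference_steps=50, progress_callback=on_progress)
    # `pipe` is an already-constructed Hunyuan3DShapePipeline; construction elided.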