diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in those public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
--- a/diffsynth_engine/pipelines/flux_image.py
+++ b/diffsynth_engine/pipelines/flux_image.py
@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import FluxPipelineConfig, FluxStateDicts, ControlType, ControlNetParams
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -34,16 +39,17 @@ from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
 
 logger = logging.get_logger(__name__)
 
-with open(FLUX_DIT_CONFIG_FILE, "r") as f:
+with open(FLUX_DIT_CONFIG_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)
 
+PREFERRED_KONTEXT_RESOLUTIONS = config["preferred_kontext_resolutions"]
+
 
 class FluxLoRAConverter(LoRAStateDictConverter):
     def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         flux_dim = 3072
         dit_rename_dict = flux_dit_config["civitai"]["rename_dict"]
         dit_suffix_rename_dict = flux_dit_config["civitai"]["suffix_rename_dict"]
-        clip_rename_dict = flux_text_encoder_config["diffusers"]["rename_dict"]
         clip_attn_rename_dict = flux_text_encoder_config["diffusers"]["attn_rename_dict"]
 
         dit_dict = {}
@@ -136,27 +142,18 @@ class FluxLoRAConverter(LoRAStateDictConverter):
                 lora_args["rank"] = lora_args["up"].shape[1]
                 rename = rename.replace(".weight", "")
                 dit_dict[rename] = lora_args
-            elif "lora_te" in key:
-                name = key.replace("lora_te1", "text_encoder")
-                name = name.replace("text_model_encoder_layers", "text_model.encoder.layers")
-                name = name.replace(".alpha", ".weight")
-                rename = ""
-                if name in clip_rename_dict:
-                    if name == "text_model.embeddings.position_embedding.weight":
-                        param = param.reshape((1, param.shape[0], param.shape[1]))
-                    rename = clip_rename_dict[name]
-                elif name.startswith("text_model.encoder.layers."):
-                    names = name.split(".")
-                    layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
-                    rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type], tail])
-                else:
-                    raise ValueError(f"Unsupported key: {key}")
+            elif "lora_te1_text_model_encoder_layers_" in key:
+                name = key.replace("lora_te1_text_model_encoder_layers_", "")
+                name = name.replace(".alpha", "")
+                layer_id, layer_type = name.split("_", 1)
+                layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
+                rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
+
                 lora_args = {}
                 lora_args["alpha"] = param
                 lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
                 lora_args["down"] = lora_state_dict[origin_key.replace(".alpha", ".lora_down.weight")]
                 lora_args["rank"] = lora_args["up"].shape[1]
-                rename = rename.replace(".weight", "")
                 te_dict[rename] = lora_args
             else:
                 raise ValueError(f"Unsupported key: {key}")
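
Note: the rewritten text-encoder branch maps flattened kohya keys directly to internal names. A standalone sketch of that mapping, with a hypothetical one-entry stand-in for flux_text_encoder_config["diffusers"]["attn_rename_dict"]:

    # Sketch of the new lora_te1 key mapping; this rename dict is an assumed
    # stand-in for flux_text_encoder_config["diffusers"]["attn_rename_dict"].
    clip_attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}

    def map_te_key(key: str) -> str:
        # e.g. "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.alpha"
        name = key.replace("lora_te1_text_model_encoder_layers_", "").replace(".alpha", "")
        layer_id, layer_type = name.split("_", 1)  # "0", "self_attn_q_proj"
        layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
        return ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])

    assert map_te_key("lora_te1_text_model_encoder_layers_0_self_attn_q_proj.alpha") == "encoders.0.attn.to_q"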
@@ -515,29 +512,20 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = FluxDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -573,8 +561,15 @@ class FluxImagePipeline(BasePipeline):
             pipe.compile()
         return pipe
 
+    def update_weights(self, state_dicts: FluxStateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(self.text_encoder_1, state_dicts.clip, self.config.device, self.config.clip_dtype)
+        self.update_component(self.text_encoder_2, state_dicts.t5, self.config.device, self.config.t5_dtype)
+        self.update_component(self.vae_decoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
+        self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
     def compile(self):
-        self.dit.compile_repeated_blocks(dynamic=True)
+        self.dit.compile_repeated_blocks()
 
     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
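
Note: update_weights lets callers hot-swap checkpoints on an already-constructed pipeline. A hedged usage sketch; only update_weights itself comes from this diff, while the keyword construction of FluxStateDicts and the safetensors loading are assumptions:

    from safetensors.torch import load_file
    from diffsynth_engine.configs import FluxStateDicts

    # pipe: an already-constructed FluxImagePipeline
    new_dicts = FluxStateDicts(
        model=load_file("flux-finetune.safetensors"),  # swapped DiT weights
        clip=pipe.text_encoder_1.state_dict(),         # reuse current encoders / VAE
        t5=pipe.text_encoder_2.state_dict(),
        vae=pipe.vae_decoder.state_dict(),
    )
    pipe.update_weights(new_dicts)  # each component reloaded on config.device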
@@ -612,7 +607,7 @@ class FluxImagePipeline(BasePipeline):
         return prompt_emb, add_text_embeds
 
     def prepare_extra_input(self, latents, positive_prompt_emb, guidance=1.0):
-        image_ids = FluxDiT.prepare_image_ids(latents)
+        image_ids = self.dit.prepare_image_ids(latents)
         guidance = torch.tensor([guidance] * latents.shape[0], device=latents.device, dtype=latents.dtype)
         text_ids = torch.zeros(positive_prompt_emb.shape[0], positive_prompt_emb.shape[1], 3).to(
             device=self.device, dtype=positive_prompt_emb.dtype
@@ -639,45 +634,45 @@ class FluxImagePipeline(BasePipeline):
     ):
         if cfg_scale <= 1.0:
             return self.predict_noise(
-                latents,
-                timestep,
-                positive_prompt_emb,
-                positive_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=positive_prompt_emb,
+                add_text_embeds=positive_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
         if not batch_cfg:
             # cfg by predict noise one by one
             positive_noise_pred = self.predict_noise(
-                latents,
-                timestep,
-                positive_prompt_emb,
-                positive_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=positive_prompt_emb,
+                add_text_embeds=positive_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
             negative_noise_pred = self.predict_noise(
-                latents,
-                timestep,
-                negative_prompt_emb,
-                negative_add_text_embeds,
-                image_emb,
-                image_ids,
-                text_ids,
-                guidance,
-                controlnet_params,
-                current_step,
-                total_step,
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=negative_prompt_emb,
+                add_text_embeds=negative_add_text_embeds,
+                image_emb=image_emb,
+                image_ids=image_ids,
+                text_ids=text_ids,
+                guidance=guidance,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
             )
             noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
             return noise_pred
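
Note: all three predict_noise call sites now pass keyword arguments; the guidance math itself is untouched. For reference, the classifier-free guidance combination checked numerically (a minimal sketch, not pipeline code):

    import torch

    def cfg_combine(positive: torch.Tensor, negative: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # negative + s * (positive - negative): s = 1 reduces to the positive
        # prediction, s > 1 extrapolates away from the negative prediction,
        # which is why cfg_scale <= 1.0 skips the negative pass entirely.
        return negative + cfg_scale * (positive - negative)

    pos, neg = torch.tensor([1.0]), torch.tensor([0.5])
    assert torch.allclose(cfg_combine(pos, neg, 1.0), pos)
    assert torch.allclose(cfg_combine(pos, neg, 3.0), torch.tensor([2.0]))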
@@ -692,17 +687,17 @@ class FluxImagePipeline(BasePipeline):
         text_ids = torch.cat([text_ids, text_ids], dim=0)
         guidance = torch.cat([guidance, guidance], dim=0)
         positive_noise_pred, negative_noise_pred = self.predict_noise(
-            latents,
-            timestep,
-            prompt_emb,
-            add_text_embeds,
-            image_emb,
-            image_ids,
-            text_ids,
-            guidance,
-            controlnet_params,
-            current_step,
-            total_step,
+            latents=latents,
+            timestep=timestep,
+            prompt_emb=prompt_emb,
+            add_text_embeds=add_text_embeds,
+            image_emb=image_emb,
+            image_ids=image_ids,
+            text_ids=text_ids,
+            guidance=guidance,
+            controlnet_params=controlnet_params,
+            current_step=current_step,
+            total_step=total_step,
         )
         noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
         return noise_pred
@@ -721,32 +716,42 @@ class FluxImagePipeline(BasePipeline):
         current_step: int,
         total_step: int,
     ):
-        origin_latents_shape = latents.shape
-        if self.config.control_type != ControlType.normal:
+        height, width = latents.shape[2:]
+        latents = self.dit.patchify(latents)
+        image_seq_len = latents.shape[1]
+
+        double_block_output, single_block_output = None, None
+        if self.config.control_type == ControlType.normal:
+            double_block_output, single_block_output = self.predict_multicontrolnet(
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=prompt_emb,
+                add_text_embeds=add_text_embeds,
+                guidance=guidance,
+                text_ids=text_ids,
+                image_ids=image_ids,
+                controlnet_params=controlnet_params,
+                current_step=current_step,
+                total_step=total_step,
+            )
+        elif self.config.control_type == ControlType.bfl_kontext:
+            for idx, controlnet_param in enumerate(controlnet_params):
+                control_latents = controlnet_param.image * controlnet_param.scale
+                control_image_ids = self.dit.prepare_image_ids(control_latents)
+                control_image_ids[..., 0] = idx + 1
+                control_latents = self.dit.patchify(control_latents)
+                latents = torch.cat((latents, control_latents), dim=1)
+                image_ids = torch.cat((image_ids, control_image_ids), dim=1)
+        else:
             controlnet_param = controlnet_params[0]
-            if self.config.control_type == ControlType.bfl_kontext:
-                latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
-                image_ids = image_ids.repeat(1, 2, 1)
-                image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
-            else:
-                latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1)
-            latents = latents.to(self.dtype)
-            controlnet_params = []
+            control_latents = controlnet_param.image * controlnet_param.scale
+            control_latents = self.dit.patchify(control_latents)
+            latents = torch.cat((latents, control_latents), dim=2)
 
-        double_block_output, single_block_output = self.predict_multicontrolnet(
-            latents=latents,
-            timestep=timestep,
-            prompt_emb=prompt_emb,
-            add_text_embeds=add_text_embeds,
-            guidance=guidance,
-            text_ids=text_ids,
-            image_ids=image_ids,
-            controlnet_params=controlnet_params,
-            current_step=current_step,
-            total_step=total_step,
-        )
+        latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -758,9 +763,10 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
-        if self.config.control_type == ControlType.bfl_kontext:
-            noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
+        noise_pred = noise_pred[:, :image_seq_len]
+        noise_pred = self.dit.unpatchify(noise_pred, height, width)
         return noise_pred
 
     def prepare_latents(
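
Note: predict_noise now operates on packed token sequences: latents are patchified once up front, control latents are concatenated along the sequence (bfl_kontext) or channel (other control modes) dimension, and the prediction is cropped back to image_seq_len and unpatchified. A minimal sketch of the 2x2 packing that Flux-style DiTs typically use; the authoritative implementation is FluxDiT.patchify/unpatchify and may differ:

    import torch
    from einops import rearrange

    def patchify(x: torch.Tensor) -> torch.Tensor:
        # [B, C, H, W] -> [B, H/2 * W/2, C * 4]: every 2x2 latent patch becomes one token
        return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    def unpatchify(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # inverse mapping from the token sequence back to [B, C, H, W]
        return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=height // 2, w=width // 2, ph=2, pw=2)

    latents = torch.randn(1, 16, 64, 64)
    tokens = patchify(latents)  # shape [1, 1024, 64]
    assert torch.equal(unpatchify(tokens, 64, 64), latents)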
@@ -782,7 +788,7 @@ class FluxImagePipeline(BasePipeline):
         sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :]
         timesteps = timesteps[t_start - 1 :]
         noise = latents
-        image = self.preprocess_image(input_image).to(device=self.device, dtype=self.dtype)
+        image = self.preprocess_image(input_image).to(device=self.device)
         latents = self.encode_image(image)
         init_latents = latents.clone()
         latents = self.sampler.add_noise(latents, noise, sigma_start)
@@ -804,26 +810,32 @@ class FluxImagePipeline(BasePipeline):
     def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int):
         self.load_models_to_device(["vae_encoder"])
         if mask is None:
+            if self.config.control_type == ControlType.bfl_kontext:
+                width, height = image.size
+                aspect_ratio = width / height
+                # Kontext is trained on specific resolutions, using one of them is recommended
+                _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
+                width, height = 16 * (width // 16), 16 * (height // 16)
             image = image.resize((width, height))
-            image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
+            image = self.preprocess_image(image).to(device=self.device)
             latent = self.encode_image(image)
         else:
             if self.config.control_type == ControlType.normal:
                 image = image.resize((width, height))
                 mask = mask.resize((width, height))
-                image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
-                mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
+                image = self.preprocess_image(image).to(device=self.device)
+                mask = self.preprocess_mask(mask).to(device=self.device)
                 masked_image = image.clone()
                 masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
                 latent = self.encode_image(masked_image)
-                mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+                mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
                 mask = 1 - mask
                 latent = torch.cat([latent, mask], dim=1)
             elif self.config.control_type == ControlType.bfl_fill:
                 image = image.resize((width, height))
                 mask = mask.resize((width, height))
-                image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
-                mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
+                image = self.preprocess_image(image).to(device=self.device)
+                mask = self.preprocess_mask(mask).to(device=self.device)
                 image = image * (1 - mask)
                 image = self.encode_image(image)
                 mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
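
Note: when no mask is given under bfl_kontext, the reference image is snapped to the closest preferred aspect ratio via the min() over PREFERRED_KONTEXT_RESOLUTIONS. A worked example using an illustrative subset of that list (the shipped list comes from flux_dit.json via config["preferred_kontext_resolutions"]):

    # Illustrative subset only; not the full shipped list.
    PREFERRED_KONTEXT_RESOLUTIONS = [(880, 1184), (1024, 1024), (1248, 832), (1392, 752)]

    width, height = 1500, 1000  # reference image size, aspect ratio 1.5
    aspect_ratio = width / height
    _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
    width, height = 16 * (width // 16), 16 * (height // 16)
    print(width, height)  # 1248 832 -- 1248/832 = 1.5, an exact ratio match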
@@ -862,6 +874,7 @@ class FluxImagePipeline(BasePipeline):
         if len(controlnet_params) > 0:
             self.load_models_to_device([])
             for param in controlnet_params:
+                control_condition = param.model.patchify(param.image)
                 current_scale = param.scale
                 if not (
                     current_step >= param.control_start * total_step and current_step <= param.control_end * total_step
@@ -872,16 +885,19 @@ class FluxImagePipeline(BasePipeline):
                 if self.offload_mode is not None:
                     empty_cache()
                     param.model.to(self.device)
+
+                attn_kwargs = self.get_attn_kwargs(latents)
                 double_block_output, single_block_output = param.model(
-                    latents,
-                    param.image,
-                    current_scale,
-                    timestep,
-                    prompt_emb,
-                    add_text_embeds,
-                    guidance,
-                    image_ids,
-                    text_ids,
+                    hidden_states=latents,
+                    control_condition=control_condition,
+                    control_scale=current_scale,
+                    timestep=timestep,
+                    prompt_emb=prompt_emb,
+                    pooled_prompt_emb=add_text_embeds,
+                    image_ids=image_ids,
+                    text_ids=text_ids,
+                    guidance=guidance,
+                    attn_kwargs=attn_kwargs,
                 )
                 if self.offload_mode is not None:
                     param.model.to("cpu")
@@ -927,8 +943,10 @@ class FluxImagePipeline(BasePipeline):
         self.dit.refresh_cache_status(num_inference_steps)
         if not isinstance(controlnet_params, list):
             controlnet_params = [controlnet_params]
-        if self.config.control_type != ControlType.normal:
-            assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
+        if self.config.control_type in [ControlType.bfl_control, ControlType.bfl_fill]:
+            assert controlnet_params and len(controlnet_params) == 1, (
+                "bfl_controlnet or bfl_fill must have one controlnet"
+            )
 
         if input_image is not None:
             width, height = input_image.size
@@ -966,8 +984,9 @@ class FluxImagePipeline(BasePipeline):
         elif self.ip_adapter is not None:
             image_emb = self.ip_adapter.encode_image(ref_image)
         elif self.redux is not None:
-            image_prompt_embeds = self.redux(ref_image)
-            positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
+            ref_prompt_embeds = self.redux(ref_image)
+            flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1))
+            positive_prompt_emb = torch.cat([positive_prompt_emb, flattened_ref_emb], dim=1)
 
         # Extra input
         image_ids, text_ids, guidance = self.prepare_extra_input(
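
Note: flattening the Redux output with view(1, -1, dim) lets a batch of reference images contribute tokens to a single conditioning sequence before concatenation. A shape-only sketch; the token count and embedding width are assumed for illustration:

    import torch

    ref_prompt_embeds = torch.randn(2, 729, 4096)  # two reference images (assumed shape)
    flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1))
    print(flattened_ref_emb.shape)  # torch.Size([1, 1458, 4096]), ready for torch.cat on dim=1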
--- a/diffsynth_engine/pipelines/hunyuan3d_shape.py
+++ b/diffsynth_engine/pipelines/hunyuan3d_shape.py
@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, Callable
 from tqdm import tqdm
 from PIL import Image
 from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
@@ -179,6 +180,7 @@ class Hunyuan3DShapePipeline(BasePipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         seed: int = 42,
+        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):
         image_emb = self.encode_image(image)
 
@@ -197,4 +199,6 @@ class Hunyuan3DShapePipeline(BasePipeline):
             noise_pred, noise_pred_uncond = model_outputs.chunk(2)
             model_outputs = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
             latents = self.sampler.step(latents, model_outputs, i)
+            if progress_callback is not None:
+                progress_callback(i, len(timesteps), "DENOISING")
         return self.decode_latents(latents)
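
Note: the denoising loop now reports progress once per step. A hedged usage sketch; the call below assumes Hunyuan3DShapePipeline is invoked with the parameters shown in this diff, and that pipe is an already-constructed instance:

    def progress_callback(current: int, total: int, status: str) -> None:
        # invoked once per denoising step with (step index, total steps, stage name)
        print(f"[{status}] step {current + 1}/{total}")

    shape = pipe(image, num_inference_steps=50, progress_callback=progress_callback)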