diffsynth-engine 0.6.1.dev2__py3-none-any.whl → 0.6.1.dev4__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -51,6 +51,20 @@ class BasePipeline:
51
51
  def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
52
52
  raise NotImplementedError()
53
53
 
54
+ def update_weights(self, state_dicts: BaseStateDicts) -> None:
55
+ raise NotImplementedError()
56
+
57
+ @staticmethod
58
+ def update_component(
59
+ component: torch.nn.Module,
60
+ state_dict: Dict[str, torch.Tensor],
61
+ device: str,
62
+ dtype: torch.dtype,
63
+ ) -> None:
64
+ if component and state_dict:
65
+ component.load_state_dict(state_dict, assign=True)
66
+ component.to(device=device, dtype=dtype, non_blocking=True)
67
+
54
68
  def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
55
69
  for lora_path, lora_scale in lora_list:
56
70
  logger.info(f"loading lora from {lora_path} with scale {lora_scale}")
@@ -573,6 +573,13 @@ class FluxImagePipeline(BasePipeline):
573
573
  pipe.compile()
574
574
  return pipe
575
575
 
576
+ def update_weights(self, state_dicts: FluxStateDicts) -> None:
577
+ self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
578
+ self.update_component(self.text_encoder_1, state_dicts.clip, self.config.device, self.config.clip_dtype)
579
+ self.update_component(self.text_encoder_2, state_dicts.t5, self.config.device, self.config.t5_dtype)
580
+ self.update_component(self.vae_decoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
581
+ self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
582
+
576
583
  def compile(self):
577
584
  self.dit.compile_repeated_blocks(dynamic=True)
578
585
 
@@ -1,4 +1,5 @@
1
1
  import torch
2
+ from typing import Optional, Callable
2
3
  from tqdm import tqdm
3
4
  from PIL import Image
4
5
  from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler
@@ -179,6 +180,7 @@ class Hunyuan3DShapePipeline(BasePipeline):
179
180
  num_inference_steps: int = 50,
180
181
  guidance_scale: float = 7.5,
181
182
  seed: int = 42,
183
+ progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
182
184
  ):
183
185
  image_emb = self.encode_image(image)
184
186
 
@@ -197,4 +199,6 @@ class Hunyuan3DShapePipeline(BasePipeline):
197
199
  noise_pred, noise_pred_uncond = model_outputs.chunk(2)
198
200
  model_outputs = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
199
201
  latents = self.sampler.step(latents, model_outputs, i)
202
+ if progress_callback is not None:
203
+ progress_callback(i, len(timesteps), "DENOISING")
200
204
  return self.decode_latents(latents)
@@ -254,6 +254,11 @@ class QwenImagePipeline(BasePipeline):
254
254
  pipe.compile()
255
255
  return pipe
256
256
 
257
+ def update_weights(self, state_dicts: QwenImageStateDicts) -> None:
258
+ self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
259
+ self.update_component(self.encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype)
260
+ self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
261
+
257
262
  def compile(self):
258
263
  self.dit.compile_repeated_blocks(dynamic=True)
259
264
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev2
3
+ Version: 0.6.1.dev4
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -139,10 +139,10 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=sOJsSs1snI-ZGPJS8utstmgj0wcYwl
139
139
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
140
140
  diffsynth_engine/models/wan/wan_vae.py,sha256=ogXrVlwmzXR4iLxjSCkBPtYW8KWebnvvd2UtPZeoziY,38853
141
141
  diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
142
- diffsynth_engine/pipelines/base.py,sha256=RTkVwWaWXr5ujqn5-UBHvdPddYwr-uvChj9-fmoXrms,13729
143
- diffsynth_engine/pipelines/flux_image.py,sha256=a-MaHuguV7Z6LJukC_Tvp7d9_2dnrAaJZ4MZH_sKsKo,49116
144
- diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=fwNKET54KjCiWDpW2S1Fk-p3nfJreZ-RH7p46VLawEQ,7911
145
- diffsynth_engine/pipelines/qwen_image.py,sha256=EAYoq1QkdOSie_yVZG9enxJJRcncwVFPfDftMo-3zBA,23745
142
+ diffsynth_engine/pipelines/base.py,sha256=7x7gEdCk_DRnGDMdPGLvNPlk-Yn2p0yQ8pNvr59i-hU,14199
143
+ diffsynth_engine/pipelines/flux_image.py,sha256=LET1gPlkXJN2xE22GRVjUgWJH0ZZBFnAe3bIvJrjb1s,49726
144
+ diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=5wTn3pqqhn19THVbM0oxxcMZb8YUdoQY1GeiQgqS6hU,8176
145
+ diffsynth_engine/pipelines/qwen_image.py,sha256=vWe1M3FU-aqckvrVeJlZgMIN55qnRz9hZ2AxGCxro1Y,24134
146
146
  diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
147
147
  diffsynth_engine/pipelines/sdxl_image.py,sha256=FaihRd9Rt_qtqup2xEbHViVIFwFZVyvekYW4lCodNKY,21692
148
148
  diffsynth_engine/pipelines/utils.py,sha256=lk7sFGEk-fGjgadLpwwppHKG-yZ0RC-4ZmHW7pRRe8A,473
@@ -185,8 +185,8 @@ diffsynth_engine/utils/video.py,sha256=GoMyc2as4_VqfWX4pjQyAWh9QObsFMov42zADVZNa
185
185
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
186
186
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
187
187
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
188
- diffsynth_engine-0.6.1.dev2.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
189
- diffsynth_engine-0.6.1.dev2.dist-info/METADATA,sha256=6bwBcpLV0q3yBPFYlDecEtAJP1ga-zi7u9tioW41D7k,1163
190
- diffsynth_engine-0.6.1.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
191
- diffsynth_engine-0.6.1.dev2.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
192
- diffsynth_engine-0.6.1.dev2.dist-info/RECORD,,
188
+ diffsynth_engine-0.6.1.dev4.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
189
+ diffsynth_engine-0.6.1.dev4.dist-info/METADATA,sha256=TW2wclPSmra3hOTVtR1hef2-XuaaP61xjnss7y3r4SM,1163
190
+ diffsynth_engine-0.6.1.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
191
+ diffsynth_engine-0.6.1.dev4.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
192
+ diffsynth_engine-0.6.1.dev4.dist-info/RECORD,,