diffsynth-engine 0.6.1.dev19__py3-none-any.whl → 0.6.1.dev21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@ import os
2
2
  import torch
3
3
  import numpy as np
4
4
  from einops import rearrange
5
- from typing import Dict, List, Tuple, Union
5
+ from typing import Dict, List, Tuple, Union, Optional
6
6
  from PIL import Image
7
7
 
8
8
  from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
@@ -70,7 +70,11 @@ class BasePipeline:
70
70
  lora_list: List[Tuple[str, Union[float, LoraConfig]]],
71
71
  fused: bool = True,
72
72
  save_original_weight: bool = False,
73
+ lora_converter: Optional[LoRAStateDictConverter] = None,
73
74
  ):
75
+ if not lora_converter:
76
+ lora_converter = self.lora_converter
77
+
74
78
  for lora_path, lora_item in lora_list:
75
79
  if isinstance(lora_item, float):
76
80
  lora_scale = lora_item
@@ -86,7 +90,7 @@ class BasePipeline:
86
90
  self.apply_scheduler_config(scheduler_config)
87
91
  logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
88
92
 
89
- lora_state_dict = self.lora_converter.convert(state_dict)
93
+ lora_state_dict = lora_converter.convert(state_dict)
90
94
  for model_name, state_dict in lora_state_dict.items():
91
95
  model = getattr(self, model_name)
92
96
  lora_args = []
@@ -830,7 +830,7 @@ class FluxImagePipeline(BasePipeline):
830
830
  masked_image = image.clone()
831
831
  masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
832
832
  latent = self.encode_image(masked_image)
833
- mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
833
+ mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
834
834
  mask = 1 - mask
835
835
  latent = torch.cat([latent, mask], dim=1)
836
836
  elif self.config.control_type == ControlType.bfl_fill:
@@ -95,8 +95,14 @@ class WanLoRAConverter(LoRAStateDictConverter):
95
95
  return state_dict
96
96
 
97
97
 
98
+ class WanLowNoiseLoRAConverter(WanLoRAConverter):
99
+ def convert(self, state_dict):
100
+ return {"dit2": super().convert(state_dict)["dit"]}
101
+
102
+
98
103
  class WanVideoPipeline(BasePipeline):
99
104
  lora_converter = WanLoRAConverter()
105
+ low_noise_lora_converter = WanLowNoiseLoRAConverter()
100
106
 
101
107
  def __init__(
102
108
  self,
@@ -133,7 +139,13 @@ class WanVideoPipeline(BasePipeline):
133
139
  self.image_encoder = image_encoder
134
140
  self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]
135
141
 
136
- def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
142
+ def load_loras(
143
+ self,
144
+ lora_list: List[Tuple[str, float]],
145
+ fused: bool = True,
146
+ save_original_weight: bool = False,
147
+ lora_converter: Optional[WanLoRAConverter] = None
148
+ ):
137
149
  assert self.config.tp_degree is None or self.config.tp_degree == 1, (
138
150
  "load LoRA is not allowed when tensor parallel is enabled; "
139
151
  "set tp_degree=None or tp_degree=1 during pipeline initialization"
@@ -142,10 +154,20 @@ class WanVideoPipeline(BasePipeline):
142
154
  "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
143
155
  "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
144
156
  )
145
- super().load_loras(lora_list, fused, save_original_weight)
157
+ super().load_loras(lora_list, fused, save_original_weight, lora_converter)
158
+
159
+ def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
160
+ assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
161
+ self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)
162
+
163
+ def load_loras_high_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
164
+ assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
165
+ self.load_loras(lora_list, fused, save_original_weight)
146
166
 
147
167
  def unload_loras(self):
148
168
  self.dit.unload_loras()
169
+ if self.dit2 is not None:
170
+ self.dit2.unload_loras()
149
171
  self.text_encoder.unload_loras()
150
172
 
151
173
  def get_default_fps(self) -> int:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev19
3
+ Version: 0.6.1.dev21
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -139,15 +139,15 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
139
139
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
140
140
  diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
141
141
  diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
142
- diffsynth_engine/pipelines/base.py,sha256=B6Md10eeAK4itILjx3biRCFwYk2usgSv7v2V9vd4fjA,14842
143
- diffsynth_engine/pipelines/flux_image.py,sha256=Dpy8AkwywuLAhvJ6cjg5TgzhSUgFQtv6p2JTTkzUHbo,50919
142
+ diffsynth_engine/pipelines/base.py,sha256=BWW7LW0E2qwu8G-6bP3nmeO7VCQxC8srOo8tE4aKA4o,14993
143
+ diffsynth_engine/pipelines/flux_image.py,sha256=vJKvnYmeeQVX2O1Zjtm4NLrltBp66VSZ-KjAUqJ8zJ8,50936
144
144
  diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
145
145
  diffsynth_engine/pipelines/qwen_image.py,sha256=jt4rg-U5qWsFD0kUeDwKzgIiTAC80Cj8aq1YQOR1_-k,33052
146
146
  diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
147
147
  diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
148
148
  diffsynth_engine/pipelines/utils.py,sha256=lk7sFGEk-fGjgadLpwwppHKG-yZ0RC-4ZmHW7pRRe8A,473
149
149
  diffsynth_engine/pipelines/wan_s2v.py,sha256=3Lkdwf5CYH2fyiD2XeZIqHUfjThsNKV9F_tQXQ-7uoU,29559
150
- diffsynth_engine/pipelines/wan_video.py,sha256=x4xnP_4VAwGW04Ja78eecfLqyzMnqdgO1J9cK-DZpv4,28173
150
+ diffsynth_engine/pipelines/wan_video.py,sha256=CF8098TIvhYTrrdfuFR7K4GpgFUezONROFJG2LL7wQk,29151
151
151
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
152
152
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
153
153
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -185,8 +185,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
185
185
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
186
186
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
187
187
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
188
- diffsynth_engine-0.6.1.dev19.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
189
- diffsynth_engine-0.6.1.dev19.dist-info/METADATA,sha256=KQ9a1ITP4r5RnWNUKEGJnnt5dduwknR3rCU2K5ETBC4,1164
190
- diffsynth_engine-0.6.1.dev19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
191
- diffsynth_engine-0.6.1.dev19.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
192
- diffsynth_engine-0.6.1.dev19.dist-info/RECORD,,
188
+ diffsynth_engine-0.6.1.dev21.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
189
+ diffsynth_engine-0.6.1.dev21.dist-info/METADATA,sha256=tdKUjrwahEQ72SA-YSPu8LsaswLKJuDrjEZI_6nYySM,1164
190
+ diffsynth_engine-0.6.1.dev21.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
191
+ diffsynth_engine-0.6.1.dev21.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
192
+ diffsynth_engine-0.6.1.dev21.dist-info/RECORD,,