diffsynth-engine 0.6.1.dev19__py3-none-any.whl → 0.6.1.dev21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/pipelines/base.py +6 -2
- diffsynth_engine/pipelines/flux_image.py +1 -1
- diffsynth_engine/pipelines/wan_video.py +24 -2
- {diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/RECORD +8 -8
- {diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/base.py

@@ -2,7 +2,7 @@ import os
 import torch
 import numpy as np
 from einops import rearrange
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image
 
 from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
@@ -70,7 +70,11 @@ class BasePipeline:
         lora_list: List[Tuple[str, Union[float, LoraConfig]]],
         fused: bool = True,
         save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
     ):
+        if not lora_converter:
+            lora_converter = self.lora_converter
+
         for lora_path, lora_item in lora_list:
             if isinstance(lora_item, float):
                 lora_scale = lora_item
@@ -86,7 +90,7 @@ class BasePipeline:
                 self.apply_scheduler_config(scheduler_config)
                 logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
 
-            lora_state_dict = self.lora_converter.convert(state_dict)
+            lora_state_dict = lora_converter.convert(state_dict)
            for model_name, state_dict in lora_state_dict.items():
                model = getattr(self, model_name)
                lora_args = []
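The practical effect of the base.py change is that `load_loras` now accepts an optional LoRA state-dict converter and falls back to the pipeline's class-level `lora_converter` when none is given. A minimal sketch of the new call shape, assuming `pipe` is an already constructed BasePipeline subclass (for example a FluxImagePipeline) and using a duck-typed converter; the LoRA path, scale, and the `"dit"` target below are placeholders, and a real converter would subclass LoRAStateDictConverter:

```python
class MyLoRAConverter:
    # Stand-in for a LoRAStateDictConverter subclass: load_loras only calls convert(),
    # which must return a {model_name: state_dict} mapping keyed by pipeline attributes.
    def convert(self, state_dict):
        return {"dit": state_dict}


# Unchanged default: with no converter argument, self.lora_converter is used.
pipe.load_loras([("my_lora.safetensors", 0.8)])

# New in dev21: an explicit converter is threaded through to convert().
pipe.load_loras([("my_lora.safetensors", 0.8)], lora_converter=MyLoRAConverter())
```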
diffsynth_engine/pipelines/flux_image.py

@@ -830,7 +830,7 @@ class FluxImagePipeline(BasePipeline):
             masked_image = image.clone()
             masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
             latent = self.encode_image(masked_image)
-            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
             mask = 1 - mask
             latent = torch.cat([latent, mask], dim=1)
         elif self.config.control_type == ControlType.bfl_fill:
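The added `.to(latent.dtype)` keeps the resized inpainting mask in the same dtype as the VAE latent before the two are concatenated: `interpolate` returns the mask in its own dtype (typically float32), while the latent is usually half precision, so without the cast the concatenation either trips a dtype error or, on recent PyTorch versions, promotes the whole conditioning tensor to float32. A small self-contained illustration; the shapes, channel count, and dtypes are chosen for the example only:

```python
import torch

latent = torch.randn(1, 16, 64, 64, dtype=torch.bfloat16)  # stand-in for an encoded VAE latent
mask = torch.rand(1, 1, 512, 512)                           # float32 inpainting mask

mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
print(mask.dtype)  # torch.float32: interpolate does not change the mask's dtype

# Without the cast, the concatenated conditioning no longer matches the latent dtype.
promoted = torch.cat([latent, 1 - mask], dim=1)
print(promoted.dtype)  # float32 under type promotion

# The dev21 fix: cast the mask first so the result stays in the latent's dtype.
kept = torch.cat([latent, 1 - mask.to(latent.dtype)], dim=1)
print(kept.dtype)  # torch.bfloat16
```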
diffsynth_engine/pipelines/wan_video.py

@@ -95,8 +95,14 @@ class WanLoRAConverter(LoRAStateDictConverter):
         return state_dict
 
 
+class WanLowNoiseLoRAConverter(WanLoRAConverter):
+    def convert(self, state_dict):
+        return {"dit2": super().convert(state_dict)["dit"]}
+
+
 class WanVideoPipeline(BasePipeline):
     lora_converter = WanLoRAConverter()
+    low_noise_lora_converter = WanLowNoiseLoRAConverter()
 
     def __init__(
         self,
@@ -133,7 +139,13 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder = image_encoder
         self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]
 
-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, float]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[WanLoRAConverter] = None
+    ):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
             "load LoRA is not allowed when tensor parallel is enabled; "
             "set tp_degree=None or tp_degree=1 during pipeline initialization"
@@ -142,10 +154,20 @@ class WanVideoPipeline(BasePipeline):
             "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
             "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
         )
-        super().load_loras(lora_list, fused, save_original_weight)
+        super().load_loras(lora_list, fused, save_original_weight, lora_converter)
+
+    def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+        assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)
+
+    def load_loras_high_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+        assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
+        self.load_loras(lora_list, fused, save_original_weight)
 
     def unload_loras(self):
         self.dit.unload_loras()
+        if self.dit2 is not None:
+            self.dit2.unload_loras()
         self.text_encoder.unload_loras()
 
     def get_default_fps(self) -> int:
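Together these wan_video.py changes wire Wan2.2's two-expert setup into the LoRA API: `load_loras_high_noise` applies weights to the primary `dit` through the usual `WanLoRAConverter`, `load_loras_low_noise` routes the same checkpoint format onto `dit2` by rekeying the converted state dict with `WanLowNoiseLoRAConverter`, and `unload_loras` now also clears `dit2` when it exists. A hedged usage sketch; `pipe` stands for a WanVideoPipeline built from a Wan2.2 checkpoint (so that `pipe.dit2` is populated), and the file names and scales are placeholders:

```python
# Both helpers assert that pipe.dit2 exists, i.e. that this is a Wan2.2 dual-DiT pipeline.
pipe.load_loras_high_noise([("wan22_high_noise_lora.safetensors", 1.0)])  # applied to pipe.dit
pipe.load_loras_low_noise([("wan22_low_noise_lora.safetensors", 1.0)])    # converted keys target pipe.dit2

# Equivalent lower-level call for the low-noise case, using the new converter argument:
pipe.load_loras([("wan22_low_noise_lora.safetensors", 1.0)],
                lora_converter=pipe.low_noise_lora_converter)

# ... run generation ...

# unload_loras() now resets dit2 (when present) alongside dit and the text encoder.
pipe.unload_loras()
```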
{diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/RECORD

@@ -139,15 +139,15 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
 diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
 diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
 diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
-diffsynth_engine/pipelines/base.py,sha256=
-diffsynth_engine/pipelines/flux_image.py,sha256=
+diffsynth_engine/pipelines/base.py,sha256=BWW7LW0E2qwu8G-6bP3nmeO7VCQxC8srOo8tE4aKA4o,14993
+diffsynth_engine/pipelines/flux_image.py,sha256=vJKvnYmeeQVX2O1Zjtm4NLrltBp66VSZ-KjAUqJ8zJ8,50936
 diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
 diffsynth_engine/pipelines/qwen_image.py,sha256=jt4rg-U5qWsFD0kUeDwKzgIiTAC80Cj8aq1YQOR1_-k,33052
 diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
 diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
 diffsynth_engine/pipelines/utils.py,sha256=lk7sFGEk-fGjgadLpwwppHKG-yZ0RC-4ZmHW7pRRe8A,473
 diffsynth_engine/pipelines/wan_s2v.py,sha256=3Lkdwf5CYH2fyiD2XeZIqHUfjThsNKV9F_tQXQ-7uoU,29559
-diffsynth_engine/pipelines/wan_video.py,sha256=
+diffsynth_engine/pipelines/wan_video.py,sha256=CF8098TIvhYTrrdfuFR7K4GpgFUezONROFJG2LL7wQk,29151
 diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
 diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -185,8 +185,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev21.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev21.dist-info/METADATA,sha256=tdKUjrwahEQ72SA-YSPu8LsaswLKJuDrjEZI_6nYySM,1164
+diffsynth_engine-0.6.1.dev21.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev21.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev21.dist-info/RECORD,,
{diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/WHEEL
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/licenses/LICENSE
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev19.dist-info → diffsynth_engine-0.6.1.dev21.dist-info}/top_level.txt
RENAMED
File without changes