diffsynth-engine 0.7.1.dev2__py3-none-any.whl → 0.7.1.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +4 -1
- diffsynth_engine/tools/qwen_image_upscaler_tool.py +386 -0
- diffsynth_engine/utils/image.py +84 -0
- {diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/RECORD +8 -7
- {diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/top_level.txt +0 -0
diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py
CHANGED
@@ -19,7 +19,10 @@ class BaseScheduler:
     def update_config(self, config_dict):
         for config_name, new_value in config_dict.items():
             if hasattr(self, config_name):
-                setattr(self, config_name, new_value)
+                actual_value = new_value
+                if isinstance(actual_value, str) and actual_value.lower() == "none":
+                    actual_value = None
+                setattr(self, config_name, actual_value)

     def restore_config(self):
         for config_name, config_value in self._stored_config.items():
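In practice this means a value that arrives as the string "none" or "None", for example from a serialized config, is now stored as Python None instead of the literal string. A minimal usage sketch, with a hypothetical attribute name chosen only for illustration:

    scheduler.update_config({"shift_terminal": "None"})  # attribute is set to None, not the string "None"
    scheduler.update_config({"shift_terminal": 0.02})    # non-string values pass through unchanged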
diffsynth_engine/tools/qwen_image_upscaler_tool.py
ADDED
@@ -0,0 +1,386 @@
+import torch
+import torch.nn as nn
+import math
+import numpy as np
+from typing import Literal, Optional, Dict
+from copy import deepcopy
+from PIL import Image
+from einops import rearrange, repeat
+from contextlib import contextmanager
+
+from diffsynth_engine.configs import QwenImagePipelineConfig
+from diffsynth_engine.pipelines.qwen_image import QwenImagePipeline
+from diffsynth_engine.models.qwen_image import QwenImageVAE
+from diffsynth_engine.models.basic.lora import LoRALinear
+from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageTransformerBlock, QwenEmbedRope
+from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.loader import load_file
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.image import adain_color_fix, wavelet_color_fix
+
+logger = logging.get_logger(__name__)
+
+
+@contextmanager
+def odtsr_forward():
+    """
+    Context manager for ODTSR forward pass optimization.
+
+    Replaces two methods:
+    1. LoRALinear.forward - to support batch CFG with dual outputs
+    2. QwenImageTransformerBlock._modulate - optimized version without repeat_interleave
+    """
+    original_lora_forward = LoRALinear.forward
+    original_modulate = QwenImageTransformerBlock._modulate
+    original_rope_forward = QwenEmbedRope.forward
+
+    def lora_batch_cfg_forward(self, x):
+        y = nn.Linear.forward(self, x)
+        if len(self._lora_dict) < 1:
+            return y
+        if x.ndim == 2:
+            y2 = y.clone()
+            for name, lora in self._lora_dict.items():
+                y2 += lora(x)
+            return torch.stack([y, y2], dim=1)
+        else:
+            L2 = x.shape[1]
+            L = L2 // 2
+            x2 = x[:, L:, :]
+            for name, lora in self._lora_dict.items():
+                y[:, L:] += lora(x2)
+            return y
+
+    def optimized_rope_forward(self, video_fhw, txt_length, device):
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        vid_freqs = []
+        max_vid_index = 0
+        idx = 0
+        for fhw in video_fhw:
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if rope_key not in self.rope_cache:
+                seq_lens = frame * height * width
+                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+                if self.scale_rope:
+                    freqs_height = torch.cat(
+                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
+                    )
+                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                else:
+                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+                self.rope_cache[rope_key] = freqs.clone().contiguous()
+            vid_freqs.append(self.rope_cache[rope_key])
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
+
+        return vid_freqs, txt_freqs
+
+
+    def optimized_modulate(self, x, mod_params, index=None):
+        if mod_params.ndim == 2:
+            shift, scale, gate = mod_params.chunk(3, dim=-1)
+            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            B, L2, C = x.shape
+            L = L2 // 2
+            shift, scale, gate = mod_params.chunk(3, dim=-1)  # Each: [B, 2, dim]
+
+            result = torch.empty_like(x)
+            gate_result = torch.empty(B, L2, gate.shape[-1], dtype=x.dtype, device=x.device)
+
+            result[:, :L] = x[:, :L] * (1 + scale[:, 0:1]) + shift[:, 0:1]
+            gate_result[:, :L] = gate[:, 0:1].expand(-1, L, -1)
+
+            result[:, L:] = x[:, L:] * (1 + scale[:, 1:2]) + shift[:, 1:2]
+            gate_result[:, L:] = gate[:, 1:2].expand(-1, L, -1)
+
+            return result, gate_result
+
+    LoRALinear.forward = lora_batch_cfg_forward
+    QwenImageTransformerBlock._modulate = optimized_modulate
+    QwenEmbedRope.forward = optimized_rope_forward
+
+    try:
+        yield
+    finally:
+        LoRALinear.forward = original_lora_forward
+        QwenImageTransformerBlock._modulate = original_modulate
+        QwenEmbedRope.forward = original_rope_forward
+
+
+class QwenImageUpscalerTool:
+    """
+    Tool for ODTSR (One-step Diffusion Transformer Super Resolution) image upscaling.
+    https://huggingface.co/double8fun/ODTSR
+    """
+
+    def __init__(
+        self,
+        pipeline: QwenImagePipeline,
+        odtsr_weight_path: Optional[str] = None,
+    ):
+        self.pipe = pipeline
+        self.device = self.pipe.device
+        self.dtype = self.pipe.dtype
+
+        # to avoid "small grid" artifacts in generated images
+        self._convert_dit_part_linear_weight()
+
+        if not odtsr_weight_path:
+            odtsr_weight_path = fetch_model("muse/ODTSR", revision="master", path="weight.safetensors")
+        odtsr_state_dict = load_file(odtsr_weight_path)
+        lora_state_dict = self._convert_odtsr_lora(odtsr_state_dict)
+        lora_state_dict_list = [(lora_state_dict, 1.0, odtsr_weight_path)]
+        self.pipe._load_lora_state_dicts(lora_state_dict_list, fused=False, save_original_weight=False)
+
+        self.new_vae = deepcopy(self.pipe.vae)
+        self._load_vae_encoder_weights(odtsr_state_dict)
+
+        sigmas = torch.linspace(1.0, 0.0, 1000 + 1)[:-1]
+        mu = 0.8
+        shift_terminal = 0.02
+        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
+        one_minus_sigmas = 1 - sigmas
+        scale_factor = one_minus_sigmas[-1] / (1 - shift_terminal)
+        self.sigmas = 1 - (one_minus_sigmas / scale_factor)
+        self.sigmas = self.sigmas.to(device=self.device)
+        self.timesteps = self.sigmas * self.pipe.noise_scheduler.num_train_timesteps
+        self.timesteps = self.timesteps.to(device=self.device)
+        self.start_timestep = 750
+        self.fixed_timestep = self.timesteps[self.start_timestep].to(device=self.device)
+        self.one_step_sigma = self.sigmas[self.start_timestep].to(device=self.device)
+
+        self.prompt = "High Contrast, hyper detailed photo, 2k UHD"
+        self.prompt_emb, self.prompt_emb_mask = self.pipe.encode_prompt(self.prompt, 1, 4096)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        qwen_model_path: str,
+        odtsr_weight_path: Optional[str] = None,
+        device: str = "cuda",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        config = QwenImagePipelineConfig(
+            model_path=qwen_model_path,
+            model_dtype=dtype,
+            device=device,
+            load_encoder=True,
+        )
+        pipe = QwenImagePipeline.from_pretrained(config)
+        return cls(pipe, odtsr_weight_path)
+
+    def _convert_dit_part_linear_weight(self):
+        """
+        Perform dtype conversion on weights of specific Linear layers in the DIT model.
+
+        This is an important trick: for Linear layers NOT in the patterns list, convert their weights
+        to float8_e4m3fn first, then convert back to the original dtype (typically bfloat16). This operation
+        matches the weight processing method used during training to avoid "small grid" artifacts in generated images.
+
+        Layers in the patterns list (such as LoRA-related layers) are skipped and their original weights remain unchanged.
+        """
+        patterns = [
+            "img_in",
+            "img_mod.1",
+            "attn.to_q",
+            "attn.to_k",
+            "attn.to_v",
+            "to_out",
+            "img_mlp.net.0.proj",
+            "img_mlp.net.2",
+        ]
+
+        def _convert_weight(parent: nn.Module, name_prefix: str = ""):
+            for name, module in list(parent.named_children()):
+                full_name = f"{name_prefix}{name}"
+                if isinstance(module, torch.nn.Linear):
+                    if not any(p in full_name for p in patterns):
+                        origin_dtype = module.weight.data.dtype
+                        module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
+                        module.weight.data = module.weight.data.to(origin_dtype)
+                        if module.bias is not None:
+                            module.bias.data = module.bias.data.to(torch.float8_e4m3fn)
+                            module.bias.data = module.bias.data.to(origin_dtype)
+                else:
+                    _convert_weight(module, name_prefix=full_name + ".")
+
+        _convert_weight(self.pipe.dit)
+
+    def _convert_odtsr_lora(self, odtsr_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        state_dict = {}
+        for key, param in odtsr_state_dict.items():
+            if "lora_A2" in key:
+                lora_b_key = key.replace("lora_A2", "lora_B2")
+                lora_b_param = odtsr_state_dict[lora_b_key]
+
+                lora_a_key = key.replace("lora_A2", "lora_A").replace("pipe.dit.", "")
+                lora_b_key = lora_b_key.replace("lora_B2", "lora_B").replace("pipe.dit.", "")
+                state_dict[lora_a_key] = param
+                state_dict[lora_b_key] = lora_b_param
+
+        return state_dict
+
+    def _load_vae_encoder_weights(self, state_dict: Dict[str, torch.Tensor]):
+        try:
+            vae_state_dict = {}
+            for k, v in state_dict.items():
+                if 'pipe.new_vae.' in k:
+                    new_key = k.replace('pipe.new_vae.', '')
+                    vae_state_dict[new_key] = v
+            if vae_state_dict:
+                self.new_vae.load_state_dict(vae_state_dict, strict=False)
+                logger.info(f"Loaded {len(vae_state_dict)} trained VAE encoder parameters")
+            else:
+                logger.warning(f"No 'pipe.new_vae.' weights found, using original VAE")
+        except Exception as e:
+            logger.error(f"Failed to load VAE encoder weights: {e}")
+            raise e
+
+
+    def add_noise(self, sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        sample = (1 - sigma) * sample + sigma * noise
+        return sample
+
+    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
+        image = torch.Tensor(np.array(image, dtype=np.float32))
+        image = image.to(dtype=self.dtype, device=self.device)
+        image = image * (2 / 255) - 1
+        image = repeat(image, f"H W C -> B C H W", **({"B": 1}))
+        return image
+
+    def _prepare_condition_latents(self, image: Image.Image, vae: QwenImageVAE, vae_tiled: bool) -> torch.Tensor:
+        image_tensor = self.preprocess_image(image).to(dtype=self.pipe.config.vae_dtype)
+        image_tensor = image_tensor.unsqueeze(2)
+
+        latents = vae.encode(
+            image_tensor,
+            device=self.device,
+            tiled=vae_tiled,
+            tile_size=self.pipe.vae_tile_size,
+            tile_stride=self.pipe.vae_tile_stride,
+        )
+        latents = latents.squeeze(2).to(device=self.device, dtype=self.dtype)
+        return latents
+
+    def _single_step_denoise(
+        self,
+        latents: torch.Tensor,
+        image_latents: torch.Tensor,
+        noise: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        prompt_emb_mask: torch.Tensor,
+        fidelity: float,
+    ) -> torch.Tensor:
+        fidelity_timestep_id = int(self.start_timestep + fidelity * (1000 - self.start_timestep) + 0.5)
+        if fidelity_timestep_id != 1000:
+            fidelity_timestep = self.timesteps[fidelity_timestep_id].to(device=self.device)
+            image_latents = self.add_noise(image_latents, noise, fidelity_timestep)
+
+        latents = self.add_noise(latents, noise, self.fixed_timestep)
+
+        with odtsr_forward():
+            noise_pred = self.pipe.predict_noise_with_cfg(
+                latents=latents,
+                image_latents=[image_latents],
+                timestep=self.fixed_timestep.unsqueeze(0),
+                prompt_emb=prompt_emb,
+                prompt_emb_mask=prompt_emb_mask,
+                negative_prompt_emb=None,
+                negative_prompt_emb_mask=None,
+                context_latents=None,
+                entity_prompt_embs=None,
+                entity_prompt_emb_masks=None,
+                negative_entity_prompt_embs=None,
+                negative_entity_prompt_emb_masks=None,
+                entity_masks=None,
+                cfg_scale=1.0,
+                batch_cfg=self.pipe.config.batch_cfg,
+            )
+
+        denoised = latents + (0 - self.one_step_sigma) * noise_pred
+        return denoised
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Image.Image,
+        scale: int = 2,
+        prompt: str = "High Contrast, hyper detailed photo, 2k UHD",
+        fidelity: float = 1.0,
+        align_method: Literal["none", "adain", "wavelet"] = "none",
+    ) -> Image.Image:
+        width, height = image.size
+        target_width, target_height = width * scale, height * scale
+        target_width_round = target_width // 16 * 16
+        target_height_round = target_height // 16 * 16
+        logger.info(f"Upscaling image from {width}x{height} to {target_width}x{target_height}")
+        vae_tiled = (target_width_round * target_height_round > 2048 * 2048)
+
+        resized_image = image.resize((target_width_round, target_height_round), Image.BICUBIC)
+
+        condition_latents = self._prepare_condition_latents(resized_image, self.pipe.vae, vae_tiled)
+        latents = self._prepare_condition_latents(resized_image, self.new_vae, vae_tiled)
+
+        noise = self.pipe.generate_noise(
+            (1, 16, target_height_round // 8, target_width_round // 8),
+            seed=42,
+            device=self.device,
+            dtype=self.dtype
+        )
+
+        prompt_emb, prompt_emb_mask = self.prompt_emb, self.prompt_emb_mask
+        if prompt != self.prompt:
+            prompt_emb, prompt_emb_mask = self.pipe.encode_prompt(prompt, 1, 4096)
+
+        denoised_latents = self._single_step_denoise(
+            latents=latents,
+            noise=noise,
+            image_latents=condition_latents,
+            prompt_emb=prompt_emb,
+            prompt_emb_mask=prompt_emb_mask,
+            fidelity=fidelity,
+        )
+
+        # Decode
+        denoised_latents = rearrange(denoised_latents, "B C H W -> B C 1 H W")
+        vae_output = rearrange(
+            self.pipe.vae.decode(
+                denoised_latents.to(self.pipe.vae.model.encoder.conv1.weight.dtype),
+                device=self.pipe.vae.model.encoder.conv1.weight.device,
+                tiled=vae_tiled,
+                tile_size=self.pipe.vae_tile_size,
+                tile_stride=self.pipe.vae_tile_stride,
+            )[0],
+            "C B H W -> B C H W",
+        )
+        result_image = self.pipe.vae_output_to_image(vae_output)
+        self.pipe.model_lifecycle_finish(["vae"])

+        if align_method == "adain":
+            result_image = adain_color_fix(target=result_image, source=resized_image)
+        elif align_method == "wavelet":
+            result_image = wavelet_color_fix(target=result_image, source=resized_image)
+
+        result_image = result_image.resize((target_width, target_height), Image.BICUBIC)
+        return result_image
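Taken together, the new tool wires a one-step ODTSR pass onto the existing QwenImagePipeline: the low-resolution input is encoded by both the stock VAE (condition latents) and the ODTSR-tuned copy (initial latents), noise is injected at the fixed timestep 750, and the result is recovered in a single step as latents - one_step_sigma * noise_pred. A hedged usage sketch follows; the checkpoint path and file names are placeholders, not taken from this diff:

    from PIL import Image
    from diffsynth_engine.tools.qwen_image_upscaler_tool import QwenImageUpscalerTool

    # Placeholder path; any Qwen-Image checkpoint accepted by QwenImagePipelineConfig should work.
    upscaler = QwenImageUpscalerTool.from_pretrained("/path/to/qwen-image", device="cuda")

    lowres = Image.open("input.png").convert("RGB")
    # scale doubles each side; align_method="wavelet" uses the color-fix helpers
    # added to diffsynth_engine/utils/image.py in this same release.
    result = upscaler(lowres, scale=2, fidelity=1.0, align_method="wavelet")
    result.save("output_2x.png")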
diffsynth_engine/utils/image.py
CHANGED
@@ -1,10 +1,13 @@
 import torch
 from torchvision import transforms
+from torchvision.transforms import ToTensor, ToPILImage
 import numpy as np
 import math
 from PIL import Image
 from enum import Enum
 from typing import List, Tuple, Optional
+from torch import Tensor
+from torch.nn import functional as F

 from diffsynth_engine.utils import logging

@@ -243,3 +246,84 @@ def _need_rescale_pil_conversion(image: np.ndarray) -> bool:
         f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
     )
     return do_rescale
+
+
+# --------------------------------------------------------------------------------
+# Color Alignment Functions
+# Based on Li Yi's implementation: https://github.com/pkuliyi2015/sd-webui-stablesr
+# --------------------------------------------------------------------------------
+def calc_mean_std(feat: Tensor, eps=1e-5):
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    b, c = size[:2]
+    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def wavelet_blur(image: Tensor, radius: int):
+    kernel_vals = [
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625],
+    ]
+    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+    kernel = kernel[None, None]
+    kernel = kernel.repeat(3, 1, 1, 1)
+    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    return output
+
+
+def wavelet_decomposition(image: Tensor, levels=5):
+    high_freq = torch.zeros_like(image)
+    for i in range(levels):
+        radius = 2 ** i
+        low_freq = wavelet_blur(image, radius)
+        high_freq += (image - low_freq)
+        image = low_freq
+
+    return high_freq, low_freq
+
+
+def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
+    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+    del content_low_freq
+    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+    del style_high_freq
+    return content_high_freq + style_low_freq
+
+
+def adain_color_fix(target: Image.Image, source: Image.Image) -> Image.Image:
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
+
+
+def wavelet_color_fix(target: Image.Image, source: Image.Image) -> Image.Image:
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
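These helpers are what the upscaler's align_method option dispatches to: adain_color_fix re-normalizes the target's per-channel mean and standard deviation to match the source, while wavelet_color_fix keeps the target's high-frequency detail and takes the low-frequency (color) component from the source. A minimal sketch of standalone use, assuming two same-sized RGB images (file names are illustrative):

    from PIL import Image
    from diffsynth_engine.utils.image import adain_color_fix, wavelet_color_fix

    generated = Image.open("generated.png").convert("RGB")   # image whose colors drifted
    reference = Image.open("reference.png").convert("RGB")   # image with the desired colors

    fixed = wavelet_color_fix(target=generated, source=reference)
    # or: fixed = adain_color_fix(target=generated, source=reference)
    fixed.save("generated_colorfixed.png")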
{diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/RECORD
RENAMED
@@ -1,7 +1,7 @@
 diffsynth_engine/__init__.py,sha256=lzUI6r47i2CCUiSIwi1IK502TL89ZG7h1yNwmM1eFvI,2588
 diffsynth_engine/algorithm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/algorithm/noise_scheduler/__init__.py,sha256=YvcwE2tCNua-OAX9GEPm0EXsINNWH4XvJMNZb-uaZMM,745
-diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py,sha256=
+diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py,sha256=WflR4KGZhbbsoTnEQhpPNR2FfJhTQqdU27A8tBN58P8,988
 diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py,sha256=ivBtxk1P_ERGxptqzYCnsguwL9aScJ5hpAgF7xgtR2I,213
 diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py,sha256=atw0CPS3TnitILpy78T6-YdDQMcBvTEHJloZzjtWqvM,1161
 diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py,sha256=cX_18RlvWtiEIcSCUflGipJdooNo9BZUHTQEm8Ltnfg,1108
@@ -186,6 +186,7 @@ diffsynth_engine/tools/flux_inpainting_tool.py,sha256=qHsYKUG20A19ujRdocpIPC4a_H
 diffsynth_engine/tools/flux_outpainting_tool.py,sha256=ff4qUj2mMYW6GMts7ifnJG7Rth55pfuggopRCyAXwJ8,3894
 diffsynth_engine/tools/flux_reference_tool.py,sha256=6v0NRZPsDEHFlPruO-ZJTB4rYWxKVAlmnYEeandD3r8,4723
 diffsynth_engine/tools/flux_replace_tool.py,sha256=AOyEGxHsaNwpTS2VChAieIfECgMxlKsRw0lWPm1k9C0,4627
+diffsynth_engine/tools/qwen_image_upscaler_tool.py,sha256=GMhV7Sphg2zgkOJhnZeLVWQJQv1d6QnOuQZXEvHgIyI,16222
 diffsynth_engine/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6gH4,2101
 diffsynth_engine/utils/constants.py,sha256=Tsn3EAByfZra-nGcx0NEcP9nWTPKaDGdatosE3BuPGE,3846
@@ -194,7 +195,7 @@ diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I
 diffsynth_engine/utils/flag.py,sha256=Ubm7FF0vHG197bmJGEplp4XauBlUaQVv-zr-w6VyEIM,2493
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
-diffsynth_engine/utils/image.py,sha256=
+diffsynth_engine/utils/image.py,sha256=jqx-UKfdc2YRBtHoL-RP2M8yce_0h2rTIJgf6mux-aU,12695
 diffsynth_engine/utils/loader.py,sha256=usIr2nUMgPxEdtEND6kboaST3ZUVr0PVWwm2sK-HXe8,1871
 diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPUI,1601
 diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
@@ -208,8 +209,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.7.1.
-diffsynth_engine-0.7.1.
-diffsynth_engine-0.7.1.
-diffsynth_engine-0.7.1.
-diffsynth_engine-0.7.1.
+diffsynth_engine-0.7.1.dev4.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.7.1.dev4.dist-info/METADATA,sha256=3n5TgI6s2eg_hPbz-ihTt9XGRjFl3TNeODyn_CxjZgg,1163
+diffsynth_engine-0.7.1.dev4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+diffsynth_engine-0.7.1.dev4.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.7.1.dev4.dist-info/RECORD,,
{diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/WHEEL
RENAMED
File without changes
{diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/licenses/LICENSE
RENAMED
File without changes
{diffsynth_engine-0.7.1.dev2.dist-info → diffsynth_engine-0.7.1.dev4.dist-info}/top_level.txt
RENAMED
File without changes