diffsynth-engine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +25 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/models/__init__.py +0 -0
- diffsynth_engine/models/base.py +55 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +137 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +393 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +504 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/attention.py +200 -0
- diffsynth_engine/models/wan/wan_dit.py +431 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +17 -0
- diffsynth_engine/pipelines/base.py +216 -0
- diffsynth_engine/pipelines/flux_image.py +548 -0
- diffsynth_engine/pipelines/sd_image.py +386 -0
- diffsynth_engine/pipelines/sdxl_image.py +430 -0
- diffsynth_engine/pipelines/wan_video.py +481 -0
- diffsynth_engine/tokenizers/__init__.py +4 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +79 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +139 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +14 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +191 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
- diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
- diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import os
|
|
3
|
+
import torch
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Callable, Dict, Optional, List, Tuple
|
|
6
|
+
from safetensors.torch import load_file
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from PIL import Image
|
|
9
|
+
|
|
10
|
+
from diffsynth_engine.models.base import LoRAStateDictConverter, split_suffix
|
|
11
|
+
from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d
|
|
12
|
+
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
|
|
13
|
+
from diffsynth_engine.pipelines import BasePipeline
|
|
14
|
+
from diffsynth_engine.tokenizers import CLIPTokenizer
|
|
15
|
+
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
|
|
16
|
+
from diffsynth_engine.algorithm.sampler import EulerSampler
|
|
17
|
+
from diffsynth_engine.utils.prompt import tokenize_long_prompt
|
|
18
|
+
from diffsynth_engine.utils.constants import SDXL_TOKENIZER_CONF_PATH
|
|
19
|
+
from diffsynth_engine.utils import logging
|
|
20
|
+
|
|
21
|
+
logger = logging.get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
# Memoized compiled patterns, keyed by the raw pattern text.
re_compiled = {}
# Matches a run of decimal digits at the start of a string.
re_digits = re.compile(r"\d+")
# Maps diffusers-style resnet sub-module names to their CompVis equivalents;
# attention sub-module names pass through unchanged.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    },
}


def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers-style LoRA UNet key into CompVis naming.

    Keys that do not match any known diffusers pattern are returned unchanged.
    """

    def match(groups, pattern):
        # Compile lazily and memoize so repeated conversions reuse the pattern.
        compiled = re_compiled.get(pattern)
        if compiled is None:
            compiled = re.compile(pattern)
            re_compiled[pattern] = compiled
        found = re.match(compiled, key)
        if not found:
            return False
        groups.clear()
        # All-digit groups become ints so they can be used in index arithmetic.
        groups.extend(int(g) if re.match(re_digits, g) else g for g in found.groups())
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f"model.diffusion_model.input_blocks.0.0{m[0]}"

    if match(m, r"lora_unet_conv_out(.*)"):
        return f"model.diffusion_model.out.2{m[0]}"

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        # diffusers linear_1/linear_2 correspond to time_embed indices 0/2.
        return f"model.diffusion_model.time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"model.diffusion_model.input_blocks.{1 + m[0] * 3 + m[2]}.{1 if m[1] == 'attentions' else 0}.{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"model.diffusion_model.middle_block.{1 if m[0] == 'attentions' else m[1] * 2}.{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"model.diffusion_model.output_blocks.{m[0] * 3 + m[2]}.{1 if m[1] == 'attentions' else 0}.{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"model.diffusion_model.input_blocks.{3 + m[0] * 3}.0.op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"model.diffusion_model.output_blocks.{2 + m[0] * 3}.{2 if m[0] > 0 else 1}.conv"
    return key
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class SDModelConfig:
    """Checkpoint paths and load dtypes for the Stable Diffusion components."""

    # UNet checkpoint; when clip_path/vae_path are omitted, the CLIP and VAE
    # weights are looked up in this checkpoint's state dict as well.
    unet_path: str | os.PathLike
    # Optional separate CLIP text-encoder checkpoint.
    clip_path: Optional[str | os.PathLike] = None
    # Optional separate VAE checkpoint.
    vae_path: Optional[str | os.PathLike] = None

    unet_dtype: torch.dtype = torch.float16
    clip_dtype: torch.dtype = torch.float16
    # NOTE(review): VAE defaults to float32, presumably for decode stability — confirm intent.
    vae_dtype: torch.dtype = torch.float32
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SDLoRAConverter(LoRAStateDictConverter):
    """Converts kohya-ss style LoRA state dicts into this package's naming scheme."""

    def _replace_kohya_te_key(self, key):
        """Map a kohya text-encoder LoRA key to the internal encoder naming."""
        key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
        # Turn "<idx>_" into "<idx>." so layer indices become attribute-path separators.
        key = re.sub(r"(\d+)_", r"\1.", key)
        key = key.replace("mlp_fc1", "fc1")
        key = key.replace("mlp_fc2", "fc2")
        key = key.replace("self_attn_q_proj", "attn.to_q")
        key = key.replace("self_attn_k_proj", "attn.to_k")
        key = key.replace("self_attn_v_proj", "attn.to_v")
        key = key.replace("self_attn_out_proj", "attn.to_out")
        return key

    def _replace_kohya_unet_key(self, key):
        """Map a kohya UNet LoRA key to the internal UNet naming.

        Raises:
            ValueError: if the converted key is not present in the UNet rename table.
        """
        rename_dict = sd_unet_config["civitai"]["rename_dict"]
        key = convert_diffusers_name_to_compvis(key)
        # Normalize remaining "<idx>_" / "_<idx>" separators to dots.
        key = re.sub(r"(\d+)_", r"\1.", key)
        key = re.sub(r"_(\d+)", r".\1", key)
        key = key.replace("ff_net", "ff.net")
        name, suffix = split_suffix(key)
        if name not in rename_dict:
            raise ValueError(f"Unsupported key: {key}, name: {name}, suffix: {suffix}")
        key = rename_dict[name] + suffix
        return key

    def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Split a kohya LoRA state dict into per-model dicts of LoRA args.

        Each ".alpha" entry anchors one LoRA module; the matching up/down
        weight tensors are gathered from the sibling keys.
        """
        unet_dict = {}
        te_dict = {}
        for key, param in lora_state_dict.items():
            # Only ".alpha" keys anchor a module; the lora_up/lora_down weight
            # keys are collected through their alpha key below.
            if ".alpha" not in key:
                continue
            lora_args = {}
            lora_args["alpha"] = param
            lora_args["up"] = lora_state_dict[key.replace(".alpha", ".lora_up.weight")].squeeze()
            lora_args["down"] = lora_state_dict[key.replace(".alpha", ".lora_down.weight")].squeeze()
            # Up weight is (out_features, rank) after squeeze.
            lora_args["rank"] = lora_args["up"].shape[1]
            key = key.replace(".alpha", "")
            if "lora_unet" in key:
                key = self._replace_kohya_unet_key(key)
                unet_dict[key] = lora_args
            elif "lora_te" in key:
                key = self._replace_kohya_te_key(key)
                te_dict[key] = lora_args
        return {"unet": unet_dict, "text_encoder": te_dict}

    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Convert a LoRA state dict; currently only the kohya-ss format is supported.

        Raises:
            ValueError: if the state dict is empty or its format is not recognized.
        """
        if not lora_state_dict:
            raise ValueError("Empty LoRA state dict")
        # Sniff the format from the first key without materializing the key list.
        key = next(iter(lora_state_dict))
        if "lora_te" in key or "lora_unet" in key:
            return self._from_kohya(lora_state_dict)
        raise ValueError(f"Unsupported key: {key}")
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class SDImagePipeline(BasePipeline):
    """Text-to-image / image-to-image pipeline for Stable Diffusion 1.x checkpoints."""

    lora_converter = SDLoRAConverter()

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: SDTextEncoder,
        unet: SDUNet,
        vae_decoder: SDVAEDecoder,
        vae_encoder: SDVAEEncoder,
        batch_cfg: bool = True,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__(device=device, dtype=dtype)
        self.noise_scheduler = ScaledLinearScheduler()
        self.sampler = EulerSampler()
        # models
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.vae_decoder = vae_decoder
        self.vae_encoder = vae_encoder
        # When True, positive and negative prompts run through the UNet in one
        # batched call; otherwise two sequential calls (lower peak memory).
        self.batch_cfg = batch_cfg
        # Names used by the offload machinery to move components on demand.
        self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]

    @classmethod
    def from_pretrained(
        cls,
        model_path_or_config: str | os.PathLike | SDModelConfig,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
        offload_mode: str | None = None,
        batch_cfg: bool = True,
    ) -> "SDImagePipeline":
        """Build a pipeline from a checkpoint path or an ``SDModelConfig``.

        When ``clip_path``/``vae_path`` are not set, the CLIP and VAE weights
        are taken from the UNet checkpoint's state dict.
        """
        cls.validate_offload_mode(offload_mode)

        # Accept plain paths (str or PathLike) as shorthand for a config that
        # loads every component from a single checkpoint.  The original check
        # only covered str, which misrouted pathlib.Path inputs.
        if isinstance(model_path_or_config, (str, os.PathLike)):
            model_config = SDModelConfig(unet_path=model_path_or_config)
        else:
            model_config = model_path_or_config

        logger.info(f"loading state dict from {model_config.unet_path} ...")
        unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)

        if model_config.vae_path is not None:
            logger.info(f"loading state dict from {model_config.vae_path} ...")
            vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
        else:
            # Fall back to the (possibly all-in-one) UNet checkpoint.
            vae_state_dict = unet_state_dict

        if model_config.clip_path is not None:
            logger.info(f"loading state dict from {model_config.clip_path} ...")
            clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
        else:
            clip_state_dict = unet_state_dict

        # With offloading, materialize on CPU first; models move to the target
        # device on demand.
        init_device = "cpu" if offload_mode else device
        tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
        with LoRAContext():
            text_encoder = SDTextEncoder.from_state_dict(
                clip_state_dict, device=init_device, dtype=model_config.clip_dtype
            )
            unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
            vae_decoder = SDVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
            vae_encoder = SDVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)

        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            vae_decoder=vae_decoder,
            vae_encoder=vae_encoder,
            batch_cfg=batch_cfg,
            device=device,
            dtype=dtype,
        )
        if offload_mode == "cpu_offload":
            pipe.enable_cpu_offload()
        elif offload_mode == "sequential_cpu_offload":
            pipe.enable_sequential_cpu_offload()
        return pipe

    @classmethod
    def from_state_dict(
        cls, state_dict: Dict[str, torch.Tensor], device: str = "cuda:0", dtype: torch.dtype = torch.float16
    ) -> "SDImagePipeline":
        """Not supported for SD pipelines; use :meth:`from_pretrained`."""
        raise NotImplementedError()

    def denoising_model(self):
        """Return the model that predicts noise at each denoising step."""
        return self.unet

    def encode_prompt(self, prompt, clip_skip):
        """Tokenize *prompt* (long prompts are chunked) and encode it with CLIP."""
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(self.device)
        prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
        return prompt_emb

    def predict_noise_with_cfg(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool = True,
    ):
        """Predict noise with classifier-free guidance.

        ``cfg_scale < 1.0`` disables guidance entirely (single conditional pass).
        """
        if cfg_scale < 1.0:
            return self.predict_noise(latents, timestep, positive_prompt_emb)
        if not batch_cfg:
            # cfg by predict noise one by one
            positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb)
            negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb)
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred
        else:
            # cfg by predict noise in one batch
            prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
            latents = torch.cat([latents, latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
            positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb).chunk(2)
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred

    def predict_noise(self, latents, timestep, prompt_emb):
        """Run a single UNet forward pass conditioned on *prompt_emb*."""
        noise_pred = self.unet(
            x=latents,
            timestep=timestep,
            context=prompt_emb,
            device=self.device,
        )
        return noise_pred

    def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True):
        """Load a single LoRA; see :meth:`load_loras`."""
        self.load_loras([(path, scale)], fused, save_original_weight)

    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True):
        """Load LoRA weights from safetensors files and attach them to the models.

        Args:
            lora_list: (path, scale) pairs applied in order.
            fused: if True, weights are merged into the base weights (frozen);
                otherwise applied dynamically at forward time.
            save_original_weight: keep a copy of the base weights so the LoRA
                can be unloaded later.

        Raises:
            ValueError: if a converted key does not resolve to a LoRA-capable module.
        """
        for lora_path, lora_scale in lora_list:
            state_dict = load_file(lora_path, device="cpu")
            lora_state_dict = self.lora_converter.convert(state_dict)
            # Renamed from the original's shadowing "state_dict" loop variable.
            for model_name, model_state_dict in lora_state_dict.items():
                model = getattr(self, model_name)
                for key, param in model_state_dict.items():
                    module = model.get_submodule(key)
                    if not isinstance(module, (LoRALinear, LoRAConv2d)):
                        raise ValueError(f"Unsupported lora key: {key}")
                    lora_args = {
                        "name": key,
                        "scale": lora_scale,
                        "rank": param["rank"],
                        "alpha": param["alpha"],
                        "up": param["up"],
                        "down": param["down"],
                        "device": self.device,
                        "dtype": self.dtype,
                        "save_original_weight": save_original_weight,
                    }
                    if fused:
                        module.add_frozen_lora(**lora_args)
                    else:
                        module.add_lora(**lora_args)

    def unload_loras(self):
        """Remove all LoRA weights from the UNet and text encoder."""
        for module in self.unet.modules():
            if isinstance(module, (LoRALinear, LoRAConv2d)):
                module.clear()
        for module in self.text_encoder.modules():
            if isinstance(module, (LoRALinear, LoRAConv2d)):
                module.clear()

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 7.5,
        clip_skip: int = 1,
        input_image: Optional[Image.Image] = None,
        mask_image: Optional[Image.Image] = None,
        denoising_strength: float = 1.0,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 20,
        tiled: bool = False,
        tile_size: int = 64,
        tile_stride: int = 32,
        seed: int | None = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate an image from *prompt*.

        Supports txt2img (no ``input_image``), img2img (``input_image`` +
        ``denoising_strength``), and inpainting (``input_image`` + ``mask_image``).
        When ``input_image`` is given, its size overrides ``height``/``width``.
        """
        if input_image is not None:
            width, height = input_image.size
        self.validate_image_size(height, width, minimum=64, multiple_of=8)
        # SD latent space is 4 channels at 1/8 spatial resolution.
        noise = self.generate_noise((1, 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.dtype)

        init_latents, latents, sigmas, timesteps = self.prepare_latents(
            noise, input_image, denoising_strength, num_inference_steps, tiled, tile_size, tile_stride
        )
        mask, overlay_image = None, None
        if mask_image is not None:
            mask, overlay_image = self.prepare_mask(input_image, mask_image, vae_scale_factor=8)
        # Initialize sampler
        self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas, mask=mask)

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        positive_prompt_emb = self.encode_prompt(prompt, clip_skip=clip_skip)
        negative_prompt_emb = self.encode_prompt(negative_prompt, clip_skip=clip_skip)

        # Denoise
        self.load_models_to_device(["unet"])
        for i, timestep in enumerate(tqdm(timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=positive_prompt_emb,
                negative_prompt_emb=negative_prompt_emb,
                cfg_scale=cfg_scale,
                batch_cfg=self.batch_cfg,
            )
            # Denoise
            latents = self.sampler.step(latents, noise_pred, i)

            # UI
            if progress_callback is not None:
                progress_callback(i, len(timesteps), "DENOISING")

        if mask_image is not None:
            # Restore the unmasked region from the initial latents.
            latents = latents * mask + init_latents * (1 - mask)
        # Decode image
        self.load_models_to_device(["vae_decoder"])
        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(vae_output)
        # Paste Overlay Image
        if mask_image is not None:
            image = image.convert("RGBA")
            image.alpha_composite(overlay_image)
            image = image.convert("RGB")
        # offload all models
        self.load_models_to_device([])
        return image
|