diffsynth-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. diffsynth_engine/__init__.py +25 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +48 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +28 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +20 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  30. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  31. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  32. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  33. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  34. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  35. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  36. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  37. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  38. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  39. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  40. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  41. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  42. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  43. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  44. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  45. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  46. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  47. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  48. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  49. diffsynth_engine/models/__init__.py +0 -0
  50. diffsynth_engine/models/base.py +55 -0
  51. diffsynth_engine/models/basic/__init__.py +0 -0
  52. diffsynth_engine/models/basic/attention.py +137 -0
  53. diffsynth_engine/models/basic/lora.py +293 -0
  54. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  55. diffsynth_engine/models/basic/timestep.py +81 -0
  56. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  57. diffsynth_engine/models/basic/unet_helper.py +244 -0
  58. diffsynth_engine/models/components/__init__.py +0 -0
  59. diffsynth_engine/models/components/clip.py +56 -0
  60. diffsynth_engine/models/components/t5.py +222 -0
  61. diffsynth_engine/models/components/vae.py +393 -0
  62. diffsynth_engine/models/flux/__init__.py +14 -0
  63. diffsynth_engine/models/flux/flux_dit.py +504 -0
  64. diffsynth_engine/models/flux/flux_text_encoder.py +90 -0
  65. diffsynth_engine/models/flux/flux_vae.py +78 -0
  66. diffsynth_engine/models/sd/__init__.py +12 -0
  67. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  68. diffsynth_engine/models/sd/sd_unet.py +293 -0
  69. diffsynth_engine/models/sd/sd_vae.py +38 -0
  70. diffsynth_engine/models/sd3/__init__.py +14 -0
  71. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  72. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  73. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  74. diffsynth_engine/models/sdxl/__init__.py +13 -0
  75. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  76. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  77. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  78. diffsynth_engine/models/utils.py +54 -0
  79. diffsynth_engine/models/wan/__init__.py +0 -0
  80. diffsynth_engine/models/wan/attention.py +200 -0
  81. diffsynth_engine/models/wan/wan_dit.py +431 -0
  82. diffsynth_engine/models/wan/wan_image_encoder.py +495 -0
  83. diffsynth_engine/models/wan/wan_text_encoder.py +264 -0
  84. diffsynth_engine/models/wan/wan_vae.py +771 -0
  85. diffsynth_engine/pipelines/__init__.py +17 -0
  86. diffsynth_engine/pipelines/base.py +216 -0
  87. diffsynth_engine/pipelines/flux_image.py +548 -0
  88. diffsynth_engine/pipelines/sd_image.py +386 -0
  89. diffsynth_engine/pipelines/sdxl_image.py +430 -0
  90. diffsynth_engine/pipelines/wan_video.py +481 -0
  91. diffsynth_engine/tokenizers/__init__.py +4 -0
  92. diffsynth_engine/tokenizers/base.py +157 -0
  93. diffsynth_engine/tokenizers/clip.py +288 -0
  94. diffsynth_engine/tokenizers/t5.py +194 -0
  95. diffsynth_engine/tokenizers/wan.py +79 -0
  96. diffsynth_engine/utils/__init__.py +0 -0
  97. diffsynth_engine/utils/constants.py +34 -0
  98. diffsynth_engine/utils/download.py +139 -0
  99. diffsynth_engine/utils/env.py +7 -0
  100. diffsynth_engine/utils/fp8_linear.py +64 -0
  101. diffsynth_engine/utils/gguf.py +415 -0
  102. diffsynth_engine/utils/loader.py +14 -0
  103. diffsynth_engine/utils/lock.py +56 -0
  104. diffsynth_engine/utils/logging.py +12 -0
  105. diffsynth_engine/utils/offload.py +44 -0
  106. diffsynth_engine/utils/parallel.py +191 -0
  107. diffsynth_engine/utils/prompt.py +9 -0
  108. diffsynth_engine/utils/video.py +40 -0
  109. diffsynth_engine-0.1.0.dist-info/LICENSE +201 -0
  110. diffsynth_engine-0.1.0.dist-info/METADATA +237 -0
  111. diffsynth_engine-0.1.0.dist-info/RECORD +113 -0
  112. diffsynth_engine-0.1.0.dist-info/WHEEL +5 -0
  113. diffsynth_engine-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,386 @@
1
+ import re
2
+ import os
3
+ import torch
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Dict, Optional, List, Tuple
6
+ from safetensors.torch import load_file
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+
10
+ from diffsynth_engine.models.base import LoRAStateDictConverter, split_suffix
11
+ from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d
12
+ from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
13
+ from diffsynth_engine.pipelines import BasePipeline
14
+ from diffsynth_engine.tokenizers import CLIPTokenizer
15
+ from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
16
+ from diffsynth_engine.algorithm.sampler import EulerSampler
17
+ from diffsynth_engine.utils.prompt import tokenize_long_prompt
18
+ from diffsynth_engine.utils.constants import SDXL_TOKENIZER_CONF_PATH
19
+ from diffsynth_engine.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
re_compiled = {}
re_digits = re.compile(r"\d+")
# Diffusers-style resnet sub-module names and their compvis equivalents;
# attention sub-modules keep their names as-is.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    },
}


def convert_diffusers_name_to_compvis(key):
    """Translate a diffusers/kohya-style UNet LoRA key into compvis naming.

    Returns the translated key, or the input unchanged when no known
    pattern matches.
    """

    def try_pattern(groups, pattern):
        # Compile on first use and memoize in the module-level cache.
        compiled = re_compiled.get(pattern)
        if compiled is None:
            compiled = re.compile(pattern)
            re_compiled[pattern] = compiled
        found = compiled.match(key)
        if found is None:
            return False
        # Purely numeric capture groups become ints so they can be used in
        # the block-index arithmetic below.
        groups[:] = [int(g) if re_digits.match(g) else g for g in found.groups()]
        return True

    g = []

    if try_pattern(g, r"lora_unet_conv_in(.*)"):
        return f"model.diffusion_model.input_blocks.0.0{g[0]}"

    if try_pattern(g, r"lora_unet_conv_out(.*)"):
        return f"model.diffusion_model.out.2{g[0]}"

    if try_pattern(g, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"model.diffusion_model.time_embed_{g[0] * 2 - 2}{g[1]}"

    if try_pattern(g, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model.diffusion_model.input_blocks.{1 + g[0] * 3 + g[2]}.{1 if g[1] == 'attentions' else 0}.{suffix}"

    if try_pattern(g, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[0], {}).get(g[2], g[2])
        return f"model.diffusion_model.middle_block.{1 if g[0] == 'attentions' else g[1] * 2}.{suffix}"

    if try_pattern(g, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(g[1], {}).get(g[3], g[3])
        return f"model.diffusion_model.output_blocks.{g[0] * 3 + g[2]}.{1 if g[1] == 'attentions' else 0}.{suffix}"

    if try_pattern(g, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"model.diffusion_model.input_blocks.{3 + g[0] * 3}.0.op"

    if try_pattern(g, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"model.diffusion_model.output_blocks.{2 + g[0] * 3}.{2 if g[0] > 0 else 1}.conv"

    return key
82
+
83
+
84
+ @dataclass
85
+ class SDModelConfig:
86
+ unet_path: str | os.PathLike
87
+ clip_path: Optional[str | os.PathLike] = None
88
+ vae_path: Optional[str | os.PathLike] = None
89
+
90
+ unet_dtype: torch.dtype = torch.float16
91
+ clip_dtype: torch.dtype = torch.float16
92
+ vae_dtype: torch.dtype = torch.float32
93
+
94
+
95
class SDLoRAConverter(LoRAStateDictConverter):
    """Converts kohya-format SD LoRA state dicts into this project's key layout."""

    def _replace_kohya_te_key(self, key):
        # Map kohya text-encoder key names onto the internal CLIP module tree.
        key = key.replace("lora_te_text_model_encoder_layers_", "encoders.")
        key = re.sub(r"(\d+)_", r"\1.", key)
        for old, new in (
            ("mlp_fc1", "fc1"),
            ("mlp_fc2", "fc2"),
            ("self_attn_q_proj", "attn.to_q"),
            ("self_attn_k_proj", "attn.to_k"),
            ("self_attn_v_proj", "attn.to_v"),
            ("self_attn_out_proj", "attn.to_out"),
        ):
            key = key.replace(old, new)
        return key

    def _replace_kohya_unet_key(self, key):
        # First to compvis naming, then normalize separators, then apply the
        # civitai rename table from the UNet config.
        rename_dict = sd_unet_config["civitai"]["rename_dict"]
        key = convert_diffusers_name_to_compvis(key)
        key = re.sub(r"(\d+)_", r"\1.", key)
        key = re.sub(r"_(\d+)", r".\1", key)
        key = key.replace("ff_net", "ff.net")
        name, suffix = split_suffix(key)
        if name not in rename_dict:
            raise ValueError(f"Unsupported key: {key}, name: {name}, suffix: {suffix}")
        return rename_dict[name] + suffix

    def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Group kohya LoRA tensors by target model, keyed by converted module name."""
        unet_dict = {}
        te_dict = {}
        # Each LoRA module is represented by an ".alpha" entry plus matching
        # ".lora_up.weight" / ".lora_down.weight" tensors.
        for key, param in lora_state_dict.items():
            if ".alpha" not in key:
                continue
            up = lora_state_dict[key.replace(".alpha", ".lora_up.weight")].squeeze()
            down = lora_state_dict[key.replace(".alpha", ".lora_down.weight")].squeeze()
            lora_args = {
                "alpha": param,
                "up": up,
                "down": down,
                "rank": up.shape[1],
            }
            base_key = key.replace(".alpha", "")
            if "lora_unet" in base_key:
                unet_dict[self._replace_kohya_unet_key(base_key)] = lora_args
            elif "lora_te" in base_key:
                te_dict[self._replace_kohya_te_key(base_key)] = lora_args
        return {"unet": unet_dict, "text_encoder": te_dict}

    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Dispatch on the first key to detect the source LoRA format."""
        key = next(iter(lora_state_dict))
        if "lora_te" in key or "lora_unet" in key:
            return self._from_kohya(lora_state_dict)
        raise ValueError(f"Unsupported key: {key}")
144
+
145
+
146
class SDImagePipeline(BasePipeline):
    """Stable Diffusion image generation pipeline.

    Wires together the CLIP tokenizer/text encoder, the UNet denoiser and the
    VAE encoder/decoder, and runs classifier-free-guidance (CFG) sampling with
    a scaled-linear noise schedule and an Euler sampler.
    """

    lora_converter = SDLoRAConverter()

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: SDTextEncoder,
        unet: SDUNet,
        vae_decoder: SDVAEDecoder,
        vae_encoder: SDVAEEncoder,
        batch_cfg: bool = True,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__(device=device, dtype=dtype)
        self.noise_scheduler = ScaledLinearScheduler()
        self.sampler = EulerSampler()
        # models
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.vae_decoder = vae_decoder
        self.vae_encoder = vae_encoder
        # If True, the positive and negative CFG branches run as one batched
        # forward pass; otherwise two sequential passes (lower peak memory).
        self.batch_cfg = batch_cfg
        # Submodel attribute names used by the offload/device-placement helpers.
        self.model_names = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]

    @classmethod
    def from_pretrained(
        cls,
        model_path_or_config: str | os.PathLike | SDModelConfig,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
        offload_mode: str | None = None,
        batch_cfg: bool = True,
    ) -> "SDImagePipeline":
        """Build a pipeline from a checkpoint path or an SDModelConfig.

        Args:
            model_path_or_config: path to a single-file checkpoint, or a config
                with separate unet/clip/vae paths and per-component dtypes.
            device: target device for inference.
            dtype: dtype used when loading state dicts.
            offload_mode: None, "cpu_offload" or "sequential_cpu_offload".
            batch_cfg: see ``__init__``.
        """
        cls.validate_offload_mode(offload_mode)

        # BUGFIX: also accept os.PathLike (e.g. pathlib.Path). Previously a
        # Path fell into the else-branch, was treated as an SDModelConfig, and
        # crashed on the `.unet_path` attribute access below.
        if isinstance(model_path_or_config, (str, os.PathLike)):
            model_config = SDModelConfig(unet_path=model_path_or_config)
        else:
            model_config = model_path_or_config

        logger.info(f"loading state dict from {model_config.unet_path} ...")
        unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)

        # CLIP and VAE weights fall back to the UNet checkpoint when no
        # dedicated file is configured (single-file checkpoint layout).
        if model_config.vae_path is not None:
            logger.info(f"loading state dict from {model_config.vae_path} ...")
            vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
        else:
            vae_state_dict = unet_state_dict

        if model_config.clip_path is not None:
            logger.info(f"loading state dict from {model_config.clip_path} ...")
            clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
        else:
            clip_state_dict = unet_state_dict

        # With offloading enabled, models are first materialized on CPU.
        init_device = "cpu" if offload_mode else device
        # NOTE(review): SD reuses the SDXL CLIP tokenizer configuration here —
        # presumably they share vocab/merges; confirm against the conf assets.
        tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
        with LoRAContext():
            text_encoder = SDTextEncoder.from_state_dict(
                clip_state_dict, device=init_device, dtype=model_config.clip_dtype
            )
            unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
            vae_decoder = SDVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
            vae_encoder = SDVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)

        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            vae_decoder=vae_decoder,
            vae_encoder=vae_encoder,
            batch_cfg=batch_cfg,
            device=device,
            dtype=dtype,
        )
        if offload_mode == "cpu_offload":
            pipe.enable_cpu_offload()
        elif offload_mode == "sequential_cpu_offload":
            pipe.enable_sequential_cpu_offload()
        return pipe

    @classmethod
    def from_state_dict(
        cls, state_dict: Dict[str, torch.Tensor], device: str = "cuda:0", dtype: torch.dtype = torch.float16
    ) -> "SDImagePipeline":
        raise NotImplementedError()

    def denoising_model(self):
        """Return the model that predicts noise (the UNet)."""
        return self.unet

    def encode_prompt(self, prompt, clip_skip):
        """Tokenize a (possibly long) prompt and return its CLIP embedding."""
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(self.device)
        prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
        return prompt_emb

    def predict_noise_with_cfg(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        positive_prompt_emb: torch.Tensor,
        negative_prompt_emb: torch.Tensor,
        cfg_scale: float,
        batch_cfg: bool = True,
    ):
        """Predict noise with classifier-free guidance.

        Combines the conditional and unconditional predictions as
        ``neg + cfg_scale * (pos - neg)``.
        """
        if cfg_scale <= 1.0:
            # At scale == 1.0 the CFG combination reduces exactly to the
            # positive prediction, so skip the redundant negative forward pass
            # (previously only scale < 1.0 took this shortcut).
            return self.predict_noise(latents, timestep, positive_prompt_emb)
        if not batch_cfg:
            # cfg by predict noise one by one
            positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb)
            negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb)
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred
        else:
            # cfg by predict noise in one batch
            prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
            latents = torch.cat([latents, latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
            positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb).chunk(2)
            noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
            return noise_pred

    def predict_noise(self, latents, timestep, prompt_emb):
        """Run a single UNet forward pass."""
        noise_pred = self.unet(
            x=latents,
            timestep=timestep,
            context=prompt_emb,
            device=self.device,
        )
        return noise_pred

    def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True):
        """Convenience wrapper to load a single LoRA file."""
        self.load_loras([(path, scale)], fused, save_original_weight)

    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True):
        """Load LoRA weight files and attach them to the UNet / text encoder.

        Args:
            lora_list: (path, scale) pairs, applied in order.
            fused: if True, merge the LoRA into the base weights (frozen).
            save_original_weight: keep originals so LoRAs can be unloaded.

        Raises:
            ValueError: when a converted key does not resolve to a LoRA-capable
                module.
        """
        for lora_path, lora_scale in lora_list:
            # Renamed from `state_dict` to avoid shadowing by the inner loop
            # variable of the same name below.
            raw_state_dict = load_file(lora_path, device="cpu")
            lora_state_dict = self.lora_converter.convert(raw_state_dict)
            for model_name, model_state_dict in lora_state_dict.items():
                model = getattr(self, model_name)
                for key, param in model_state_dict.items():
                    module = model.get_submodule(key)
                    if not isinstance(module, (LoRALinear, LoRAConv2d)):
                        raise ValueError(f"Unsupported lora key: {key}")
                    lora_args = {
                        "name": key,
                        "scale": lora_scale,
                        "rank": param["rank"],
                        "alpha": param["alpha"],
                        "up": param["up"],
                        "down": param["down"],
                        "device": self.device,
                        "dtype": self.dtype,
                        "save_original_weight": save_original_weight,
                    }
                    if fused:
                        module.add_frozen_lora(**lora_args)
                    else:
                        module.add_lora(**lora_args)

    def unload_loras(self):
        """Detach all LoRA layers from the UNet and the text encoder."""
        for model in (self.unet, self.text_encoder):
            for _, module in model.named_modules():
                if isinstance(module, (LoRALinear, LoRAConv2d)):
                    module.clear()

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 7.5,
        clip_skip: int = 1,
        input_image: Optional[Image.Image] = None,
        mask_image: Optional[Image.Image] = None,
        denoising_strength: float = 1.0,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 20,
        tiled: bool = False,
        tile_size: int = 64,
        tile_stride: int = 32,
        seed: int | None = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate an image from a prompt (optionally img2img / inpainting).

        When ``input_image`` is given, its size overrides ``height``/``width``
        and ``denoising_strength`` controls how much of it is preserved. When
        ``mask_image`` is also given, only the masked region is regenerated and
        the original is composited back afterwards.
        """
        if input_image is not None:
            width, height = input_image.size
        self.validate_image_size(height, width, minimum=64, multiple_of=8)
        # SD latent space: 4 channels at 1/8 spatial resolution.
        noise = self.generate_noise((1, 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.dtype)

        init_latents, latents, sigmas, timesteps = self.prepare_latents(
            noise, input_image, denoising_strength, num_inference_steps, tiled, tile_size, tile_stride
        )
        mask, overlay_image = None, None
        if mask_image is not None:
            mask, overlay_image = self.prepare_mask(input_image, mask_image, vae_scale_factor=8)
        # Initialize sampler
        self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas, mask=mask)

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        positive_prompt_emb = self.encode_prompt(prompt, clip_skip=clip_skip)
        negative_prompt_emb = self.encode_prompt(negative_prompt, clip_skip=clip_skip)

        # Denoise
        self.load_models_to_device(["unet"])
        for i, timestep in enumerate(tqdm(timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=positive_prompt_emb,
                negative_prompt_emb=negative_prompt_emb,
                cfg_scale=cfg_scale,
                batch_cfg=self.batch_cfg,
            )
            # Denoise
            latents = self.sampler.step(latents, noise_pred, i)

            # UI
            if progress_callback is not None:
                progress_callback(i, len(timesteps), "DENOISING")

        if mask_image is not None:
            # Blend regenerated region with the original latents outside the mask.
            latents = latents * mask + init_latents * (1 - mask)
        # Decode image
        self.load_models_to_device(["vae_decoder"])
        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(vae_output)
        # Paste Overlay Image
        if mask_image is not None:
            image = image.convert("RGBA")
            image.alpha_composite(overlay_image)
            image = image.convert("RGB")
        # offload all models
        self.load_models_to_device([])
        return image