diffsynth-engine 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,395 @@
1
+ import os
2
+ import re
3
+ import torch
4
+ from typing import Callable, Dict, Optional
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+
9
+ from diffsynth_engine.models.base import split_suffix
10
+ from diffsynth_engine.models.basic.lora import LoRAContext
11
+ from diffsynth_engine.models.basic.timestep import TemporalTimesteps
12
+ from diffsynth_engine.models.sdxl import (
13
+ SDXLTextEncoder,
14
+ SDXLTextEncoder2,
15
+ SDXLVAEDecoder,
16
+ SDXLVAEEncoder,
17
+ SDXLUNet,
18
+ sdxl_unet_config,
19
+ )
20
+ from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
21
+ from diffsynth_engine.tokenizers import CLIPTokenizer
22
+ from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
23
+ from diffsynth_engine.algorithm.sampler import EulerSampler
24
+ from diffsynth_engine.utils.prompt import tokenize_long_prompt
25
+ from diffsynth_engine.utils.constants import SDXL_TOKENIZER_CONF_PATH, SDXL_TOKENIZER_2_CONF_PATH
26
+ from diffsynth_engine.utils import logging
27
+
28
# Module-level logger (project logging wrapper), namespaced to this module.
logger = logging.get_logger(__name__)
29
+
30
+
31
class SDXLLoRAConverter(LoRAStateDictConverter):
    """Converts kohya-ss SDXL LoRA checkpoints into per-module LoRA dicts.

    `convert` returns {"unet": ..., "text_encoder": ..., "text_encoder_2": ...},
    each mapping an internal module key to {"up", "down", "alpha", "rank"}.
    """

    def _replace_kohya_te_key(self, key, prefix):
        """Map a kohya text-encoder LoRA key (given its `lora_teN_...` prefix)
        to the internal encoder naming shared by both CLIP text encoders."""
        key = key.replace(prefix, "encoders.")
        key = re.sub(r"(\d+)_", r"\1.", key)
        key = key.replace("mlp_fc1", "fc1")
        key = key.replace("mlp_fc2", "fc2")
        key = key.replace("self_attn_q_proj", "attn.to_q")
        key = key.replace("self_attn_k_proj", "attn.to_k")
        key = key.replace("self_attn_v_proj", "attn.to_v")
        key = key.replace("self_attn_out_proj", "attn.to_out")
        return key

    def _replace_kohya_te1_key(self, key):
        # First text encoder ("te1"); shared logic lives in _replace_kohya_te_key
        # (the two methods previously duplicated the same body verbatim).
        return self._replace_kohya_te_key(key, "lora_te1_text_model_encoder_layers_")

    def _replace_kohya_te2_key(self, key):
        # Second text encoder ("te2").
        return self._replace_kohya_te_key(key, "lora_te2_text_model_encoder_layers_")

    def _replace_kohya_unet_key(self, key):
        """Map a kohya UNet LoRA key via the civitai rename table.

        Raises:
            ValueError: if the normalized key is absent from the rename table.
        """
        rename_dict = sdxl_unet_config["civitai"]["rename_dict"]
        key = key.replace("lora_unet_", "model.diffusion_model.")
        key = key.replace("ff_net", "ff.net")
        key = re.sub(r"(\d+)_", r"\1.", key)
        key = re.sub(r"_(\d+)", r".\1", key)
        name, suffix = split_suffix(key)
        if name not in rename_dict:
            raise ValueError(f"Unsupported key: {key}, name: {name}, suffix: {suffix}")
        key = rename_dict[name] + suffix
        return key

    def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Group kohya LoRA tensors by target module (unet / te1 / te2).

        Iterates the ".alpha" entries and collects the matching up/down
        weights; `rank` is taken from dim 1 of the up-projection weight.

        Raises:
            ValueError: on a key that matches none of the known prefixes.
        """
        unet_dict = {}
        te1_dict = {}
        te2_dict = {}
        for key, param in lora_state_dict.items():
            lora_args = {}
            if ".alpha" not in key:
                continue
            lora_args["alpha"] = param
            lora_args["up"] = lora_state_dict[key.replace(".alpha", ".lora_up.weight")]
            lora_args["down"] = lora_state_dict[key.replace(".alpha", ".lora_down.weight")]
            lora_args["rank"] = lora_args["up"].shape[1]
            key = key.replace(".alpha", "")
            if "lora_te1" in key:
                key = self._replace_kohya_te1_key(key)
                te1_dict[key] = lora_args
            elif "lora_te2" in key:
                key = self._replace_kohya_te2_key(key)
                te2_dict[key] = lora_args
            elif "lora_unet" in key:
                key = self._replace_kohya_unet_key(key)
                unet_dict[key] = lora_args
            else:
                raise ValueError(f"Unsupported key: {key}")
        return {"unet": unet_dict, "text_encoder": te1_dict, "text_encoder_2": te2_dict}

    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Dispatch on checkpoint flavor; only kohya-style keys are supported.

        Raises:
            ValueError: if the first key does not look like a kohya LoRA key.
        """
        key = list(lora_state_dict.keys())[0]
        if "lora_te1" in key or "lora_te2" in key or "lora_unet" in key:
            return self._from_kohya(lora_state_dict)
        raise ValueError(f"Unsupported key: {key}")
97
+
98
+
99
@dataclass
class SDXLModelConfig:
    """Checkpoint paths and load dtypes for each SDXL pipeline component.

    Only `unet_path` is required; components whose path is None are loaded
    from the UNet checkpoint file instead (see SDXLImagePipeline.from_pretrained).
    """

    # Checkpoint locations. None -> weights are read from the UNet checkpoint.
    unet_path: str | os.PathLike
    clip_l_path: Optional[str | os.PathLike] = None
    clip_g_path: Optional[str | os.PathLike] = None
    vae_path: Optional[str | os.PathLike] = None

    # Per-component load dtypes. NOTE(review): the VAE defaults to float32,
    # presumably for numerical stability of decode — confirm before changing.
    unet_dtype: torch.dtype = torch.float16
    clip_l_dtype: torch.dtype = torch.float16
    clip_g_dtype: torch.dtype = torch.float16
    vae_dtype: torch.dtype = torch.float32
110
+
111
+
112
class SDXLImagePipeline(BasePipeline):
    """Stable Diffusion XL image generation pipeline (txt2img / img2img / inpaint)."""

    # Converts kohya-style LoRA checkpoints into per-module state dicts.
    lora_converter = SDXLLoRAConverter()

115
+ def __init__(
116
+ self,
117
+ tokenizer: CLIPTokenizer,
118
+ tokenizer_2: CLIPTokenizer,
119
+ text_encoder: SDXLTextEncoder,
120
+ text_encoder_2: SDXLTextEncoder2,
121
+ unet: SDXLUNet,
122
+ vae_decoder: SDXLVAEDecoder,
123
+ vae_encoder: SDXLVAEEncoder,
124
+ batch_cfg: bool = True,
125
+ device: str = "cuda",
126
+ dtype: torch.dtype = torch.float16,
127
+ ):
128
+ super().__init__(device=device, dtype=dtype)
129
+ self.noise_scheduler = ScaledLinearScheduler()
130
+ self.sampler = EulerSampler()
131
+ # models
132
+ self.tokenizer = tokenizer
133
+ self.tokenizer_2 = tokenizer_2
134
+ self.text_encoder = text_encoder
135
+ self.text_encoder_2 = text_encoder_2
136
+ self.unet = unet
137
+ self.vae_decoder = vae_decoder
138
+ self.vae_encoder = vae_encoder
139
+ self.add_time_proj = TemporalTimesteps(
140
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype
141
+ )
142
+ self.batch_cfg = batch_cfg
143
+ self.model_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
144
+
145
+ @classmethod
146
+ def from_pretrained(
147
+ cls,
148
+ model_path_or_config: str | os.PathLike | SDXLModelConfig,
149
+ device: str = "cuda:0",
150
+ dtype: torch.dtype = torch.float16,
151
+ offload_mode: str | None = None,
152
+ batch_cfg: bool = True,
153
+ ) -> "SDXLImagePipeline":
154
+ cls.validate_offload_mode(offload_mode)
155
+
156
+ if isinstance(model_path_or_config, str):
157
+ model_config = SDXLModelConfig(
158
+ unet_path=model_path_or_config, unet_dtype=dtype, clip_l_dtype=dtype, clip_g_dtype=dtype
159
+ )
160
+ else:
161
+ model_config = model_path_or_config
162
+
163
+ logger.info(f"loading state dict from {model_config.unet_path} ...")
164
+ unet_state_dict = cls.load_model_checkpoint(model_config.unet_path, device="cpu", dtype=dtype)
165
+
166
+ if model_config.vae_path is not None:
167
+ logger.info(f"loading state dict from {model_config.vae_path} ...")
168
+ vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
169
+ else:
170
+ vae_state_dict = unet_state_dict
171
+
172
+ if model_config.clip_l_path is not None:
173
+ logger.info(f"loading state dict from {model_config.clip_l_path} ...")
174
+ clip_l_state_dict = cls.load_model_checkpoint(model_config.clip_l_path, device="cpu", dtype=dtype)
175
+ else:
176
+ clip_l_state_dict = unet_state_dict
177
+
178
+ if model_config.clip_g_path is not None:
179
+ logger.info(f"loading state dict from {model_config.clip_g_path} ...")
180
+ clip_g_state_dict = cls.load_model_checkpoint(model_config.clip_g_path, device="cpu", dtype=dtype)
181
+ else:
182
+ clip_g_state_dict = unet_state_dict
183
+
184
+ init_device = "cpu" if offload_mode else device
185
+ tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
186
+ tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
187
+ with LoRAContext():
188
+ text_encoder = SDXLTextEncoder.from_state_dict(
189
+ clip_l_state_dict, device=init_device, dtype=model_config.clip_l_dtype
190
+ )
191
+ text_encoder_2 = SDXLTextEncoder2.from_state_dict(
192
+ clip_g_state_dict, device=init_device, dtype=model_config.clip_g_dtype
193
+ )
194
+ unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=model_config.unet_dtype)
195
+ vae_decoder = SDXLVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
196
+ vae_encoder = SDXLVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
197
+
198
+ pipe = cls(
199
+ tokenizer=tokenizer,
200
+ tokenizer_2=tokenizer_2,
201
+ text_encoder=text_encoder,
202
+ text_encoder_2=text_encoder_2,
203
+ unet=unet,
204
+ vae_decoder=vae_decoder,
205
+ vae_encoder=vae_encoder,
206
+ batch_cfg=batch_cfg,
207
+ device=device,
208
+ dtype=dtype,
209
+ )
210
+ if offload_mode == "cpu_offload":
211
+ pipe.enable_cpu_offload()
212
+ elif offload_mode == "sequential_cpu_offload":
213
+ pipe.enable_sequential_cpu_offload()
214
+ return pipe
215
+
216
    @classmethod
    def from_state_dict(
        cls, state_dict: Dict[str, torch.Tensor], device: str = "cuda:0", dtype: torch.dtype = torch.float16
    ) -> "SDXLImagePipeline":
        """Not supported for this pipeline; use `from_pretrained` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
221
+
222
    def denoising_model(self):
        """Return the denoising network used by this pipeline (the SDXL UNet)."""
        return self.unet
224
+
225
+ def encode_prompt(self, prompt, clip_skip):
226
+ input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(self.device)
227
+ prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
228
+
229
+ input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(self.device)
230
+ prompt_emb_2, add_text_embeds = self.text_encoder_2(input_ids_2, clip_skip=clip_skip)
231
+
232
+ # Merge
233
+ if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
234
+ max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
235
+ prompt_emb_1 = prompt_emb_1[:max_batch_size]
236
+ prompt_emb_2 = prompt_emb_2[:max_batch_size]
237
+ prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
238
+
239
+ # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
240
+ add_text_embeds = add_text_embeds[0:1]
241
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0] * prompt_emb.shape[1], -1))
242
+
243
+ return prompt_emb, add_text_embeds
244
+
245
+ def prepare_add_time_id(self, latents):
246
+ height, width = latents.shape[2] * 8, latents.shape[3] * 8
247
+ add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
248
+ # original_size_as_tuple(height, width)
249
+ # crop_coords_top_left(0, 0)
250
+ # target_size_as_tuple(height, width)
251
+ return add_time_id
252
+
253
+ def prepare_add_embeds(self, add_text_embeds, add_time_id, dtype):
254
+ time_embeds = self.add_time_proj(add_time_id)
255
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
256
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
257
+ add_embeds = add_embeds.to(dtype)
258
+ return add_embeds
259
+
260
+ def predict_noise_with_cfg(
261
+ self,
262
+ latents: torch.Tensor,
263
+ timestep: torch.Tensor,
264
+ positive_prompt_emb: torch.Tensor,
265
+ negative_prompt_emb: torch.Tensor,
266
+ positive_add_text_embeds: torch.Tensor,
267
+ negative_add_text_embeds: torch.Tensor,
268
+ add_time_id: torch.Tensor,
269
+ cfg_scale: float,
270
+ batch_cfg: bool = True,
271
+ ):
272
+ if cfg_scale <= 1.0:
273
+ return self.predict_noise(latents, timestep, positive_prompt_emb, add_time_id)
274
+ if not batch_cfg:
275
+ # cfg by predict noise one by one
276
+ positive_noise_pred = self.predict_noise(
277
+ latents, timestep, positive_prompt_emb, positive_add_text_embeds, add_time_id
278
+ )
279
+ negative_noise_pred = self.predict_noise(
280
+ latents, timestep, negative_prompt_emb, negative_add_text_embeds, add_time_id
281
+ )
282
+ noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
283
+ return noise_pred
284
+ else:
285
+ # cfg by predict noise in one batch
286
+ add_time_ids = torch.cat([add_time_id, add_time_id], dim=0)
287
+ prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
288
+ add_text_embeds = torch.cat([positive_add_text_embeds, negative_add_text_embeds], dim=0)
289
+ latents = torch.cat([latents, latents], dim=0)
290
+ timestep = torch.cat([timestep, timestep], dim=0)
291
+ positive_noise_pred, negative_noise_pred = self.predict_noise(
292
+ latents, timestep, prompt_emb, add_text_embeds, add_time_ids
293
+ )
294
+ noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
295
+ return noise_pred
296
+
297
+ def predict_noise(self, latents, timestep, prompt_emb, add_text_embeds, add_time_id):
298
+ y = self.prepare_add_embeds(add_text_embeds, add_time_id, self.dtype)
299
+ noise_pred = self.unet(
300
+ x=latents,
301
+ timestep=timestep,
302
+ y=y,
303
+ context=prompt_emb,
304
+ device=self.device,
305
+ )
306
+ return noise_pred
307
+
308
+ def unload_loras(self):
309
+ self.unet.unload_loras()
310
+ self.text_encoder.unload_loras()
311
+ self.text_encoder_2.unload_loras()
312
+
313
    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 7.5,
        clip_skip: int = 2,
        input_image: Image.Image | None = None,
        mask_image: Image.Image | None = None,
        denoising_strength: float = 1.0,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 20,
        tiled: bool = False,
        tile_size: int = 64,
        tile_stride: int = 32,
        seed: int | None = None,
        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
    ):
        """Generate an image from `prompt` (optionally img2img / inpainting).

        Args:
            prompt: positive text prompt.
            negative_prompt: negative prompt; when empty, zero tensors are used
                as the negative embeddings (a1111-webui convention).
            cfg_scale: classifier-free guidance strength; <= 1 disables CFG.
            clip_skip: CLIP layer-skipping passed to the text encoders.
            input_image: if given, enables img2img and its size overrides
                `height`/`width`.
            mask_image: inpainting mask; used together with `input_image`.
            denoising_strength: img2img strength passed to prepare_latents.
            height, width: output size in pixels (validated: >= 64, multiple of 8).
            num_inference_steps: number of denoising steps.
            tiled, tile_size, tile_stride: tiled VAE options.
            seed: RNG seed for the initial noise.
            progress_callback: called as (current, total, "DENOISING") each step.

        Returns:
            A PIL image in RGB mode.
        """
        if input_image is not None:
            # img2img: the source image dictates the output resolution.
            width, height = input_image.size
        self.validate_image_size(height, width, minimum=64, multiple_of=8)
        # Batch 1, 4 latent channels, 8x spatial downscale.
        noise = self.generate_noise((1, 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.dtype)

        init_latents, latents, sigmas, timesteps = self.prepare_latents(
            noise, input_image, denoising_strength, num_inference_steps, tiled, tile_size, tile_stride
        )
        mask, overlay_image = None, None
        if mask_image is not None:
            mask, overlay_image = self.prepare_mask(input_image, mask_image, vae_scale_factor=8)
        # Initialize sampler (it also receives the inpainting mask, if any)
        self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas, mask=mask)

        # Encode prompts
        self.load_models_to_device(["text_encoder", "text_encoder_2"])
        positive_prompt_emb, positive_add_text_embeds = self.encode_prompt(prompt, clip_skip=clip_skip)
        if negative_prompt != "":
            negative_prompt_emb, negative_add_text_embeds = self.encode_prompt(negative_prompt, clip_skip=clip_skip)
        else:
            # from automatic1111/stable-diffusion-webui: an empty negative
            # prompt is represented by all-zero embeddings
            negative_prompt_emb, negative_add_text_embeds = (
                torch.zeros_like(positive_prompt_emb),
                torch.zeros_like(positive_add_text_embeds),
            )

        # Prepare extra input (SDXL size/crop conditioning ids)
        add_time_id = self.prepare_add_time_id(latents)

        # Denoise
        self.load_models_to_device(["unet"])
        for i, timestep in enumerate(tqdm(timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
            # Classifier-free guidance
            noise_pred = self.predict_noise_with_cfg(
                latents=latents,
                timestep=timestep,
                positive_prompt_emb=positive_prompt_emb,
                negative_prompt_emb=negative_prompt_emb,
                positive_add_text_embeds=positive_add_text_embeds,
                negative_add_text_embeds=negative_add_text_embeds,
                add_time_id=add_time_id,
                cfg_scale=cfg_scale,
                batch_cfg=self.batch_cfg,
            )
            # Denoise
            latents = self.sampler.step(latents, noise_pred, i)
            # UI
            if progress_callback is not None:
                progress_callback(i, len(timesteps), "DENOISING")
        if mask_image is not None:
            # Inpainting: keep masked-out regions pinned to the initial latents.
            latents = latents * mask + init_latents * (1 - mask)
        # Decode image
        self.load_models_to_device(["vae_decoder"])
        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(vae_output)

        if mask_image is not None:
            # Re-composite the untouched source pixels over the decoded image.
            image = image.convert("RGBA")
            image.alpha_composite(overlay_image)
            image = image.convert("RGB")
        # offload all models
        self.load_models_to_device([])
        return image