hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/format/__init__.py +2 -2
- hcpdiff/ckpt_manager/format/diffusers.py +19 -4
- hcpdiff/ckpt_manager/format/emb.py +8 -3
- hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
- hcpdiff/ckpt_manager/format/sd_single.py +28 -5
- hcpdiff/data/cache/vae.py +10 -2
- hcpdiff/data/handler/text.py +15 -14
- hcpdiff/diffusion/sampler/__init__.py +2 -1
- hcpdiff/diffusion/sampler/base.py +17 -6
- hcpdiff/diffusion/sampler/diffusers.py +4 -3
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
- hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
- hcpdiff/diffusion/sampler/timer/base.py +26 -0
- hcpdiff/diffusion/sampler/timer/shift.py +49 -0
- hcpdiff/easy/__init__.py +2 -1
- hcpdiff/easy/cfg/sd15_train.py +1 -3
- hcpdiff/easy/model/__init__.py +1 -1
- hcpdiff/easy/model/loader.py +33 -11
- hcpdiff/easy/sampler.py +8 -1
- hcpdiff/loss/__init__.py +4 -3
- hcpdiff/loss/charbonnier.py +17 -0
- hcpdiff/loss/vlb.py +2 -2
- hcpdiff/loss/weighting.py +29 -11
- hcpdiff/models/__init__.py +1 -1
- hcpdiff/models/cfg_context.py +5 -3
- hcpdiff/models/compose/__init__.py +2 -1
- hcpdiff/models/compose/compose_hook.py +69 -67
- hcpdiff/models/compose/compose_textencoder.py +59 -45
- hcpdiff/models/compose/compose_tokenizer.py +48 -11
- hcpdiff/models/compose/flux.py +75 -0
- hcpdiff/models/compose/sdxl.py +86 -0
- hcpdiff/models/text_emb_ex.py +13 -9
- hcpdiff/models/textencoder_ex.py +8 -38
- hcpdiff/models/wrapper/__init__.py +2 -1
- hcpdiff/models/wrapper/flux.py +75 -0
- hcpdiff/models/wrapper/pixart.py +13 -1
- hcpdiff/models/wrapper/sd.py +17 -8
- hcpdiff/parser/embpt.py +7 -7
- hcpdiff/utils/net_utils.py +22 -12
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +145 -18
- hcpdiff/workflow/text.py +49 -18
- hcpdiff/workflow/vae.py +10 -2
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
- hcpdiff/models/compose/sdxl_composer.py +0 -39
- hcpdiff/utils/inpaint_pipe.py +0 -790
- hcpdiff/utils/pipe_hook.py +0 -656
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/utils/pipe_hook.py
DELETED
@@ -1,656 +0,0 @@
from typing import Union, List, Optional, Callable, Dict, Any

import PIL
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
from einops import repeat

class HookPipe_T2I(StableDiffusionPipeline):
    @property
    def _execution_device(self) -> torch.device:
        return torch.device('cuda')

    @property
    def device(self) -> torch.device:
        return torch.device('cuda')

    def proc_prompt(self, device, num_inference_steps, prompt_embeds = None, negative_prompt_embeds = None) -> List[torch.Tensor]:
        if not isinstance(prompt_embeds, list): # to emb for each step
            prompt_embeds = [prompt_embeds]*num_inference_steps
        if not isinstance(negative_prompt_embeds, list): # to emb for each step
            negative_prompt_embeds = [negative_prompt_embeds]*num_inference_steps

        prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in prompt_embeds]
        negative_prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in negative_prompt_embeds]

        prompt_embeds = [torch.cat([emb_neg, emb_pos]) for emb_pos, emb_neg in zip(prompt_embeds, negative_prompt_embeds)]
        return prompt_embeds # List[emb_step_i]*num_inference_steps

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        pooled_output: Optional[torch.FloatTensor] = None,
        crop_coord: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size*self.vae_scale_factor
        width = width or self.unet.config.sample_size*self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        # self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        elif isinstance(prompt_embeds, list):
            batch_size = prompt_embeds[0].shape[0]
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale>1.0

        # 3. Encode input prompt
        prompt_embeds = self.proc_prompt(device, num_inference_steps,
                                         prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds[0].dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # SDXL inputs
        if pooled_output is not None:
            if crop_coord is None:
                crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
            else:
                crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
            crop_info = crop_info.to(device).repeat(batch_size, 1)
            pooled_output = pooled_output.to(device)

            if do_classifier_free_guidance:
                crop_info = torch.cat([crop_info, crop_info], dim=0)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if pooled_output is None:
                    if isinstance(self.unet, PixArtTransformer2DModel):
                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
                        noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
                                               encoder_attention_mask=encoder_attention_mask,
                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
                    else:
                        noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
                                               encoder_attention_mask=encoder_attention_mask,
                                               cross_attention_kwargs=cross_attention_kwargs).sample
                else:
                    added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
                                           encoder_attention_mask=encoder_attention_mask,
                                           cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)

                # learned sigma
                if self.unet.config.out_channels // 2 == num_channels_latents:
                    noise_pred = noise_pred.chunk(2, dim=1)[0]

                # x_t -> x_0
                alpha_prod_t = alphas_cumprod[t.long()]
                beta_prod_t = 1-alpha_prod_t
                latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0

                # compute the previous noisy sample x_t -> x_t-1
                sc_out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
                latents = sc_out.prev_sample

                # call the callback, if provided
                if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i%callback_steps == 0:
                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
                        if latents is None:
                            return None

        latents = latents.to(dtype=self.vae.dtype)
        if not output_type == "latent":
            image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        do_denormalize = [True]*image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)

class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
    @property
    def _execution_device(self) -> torch.device:
        return torch.device('cuda')

    @property
    def device(self) -> torch.device:
        return torch.device('cuda')

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        pooled_output: Optional[torch.FloatTensor] = None,
        crop_coord: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale>1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)
        image = repeat(image, 'n ... -> (n b) ...', b=batch_size)
        height, width = image.shape[2:]

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        ).to(self.unet.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # SDXL inputs
        if pooled_output is not None:
            if crop_coord is None:
                crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
            else:
                crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
            crop_info = crop_info.to(device).repeat(batch_size*num_images_per_prompt, 1)
            pooled_output = pooled_output.to(device)

            if do_classifier_free_guidance:
                crop_info = torch.cat([crop_info, crop_info], dim=0)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                if pooled_output is None:
                    if isinstance(self.unet, PixArtTransformer2DModel):
                        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                               cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
                    else:
                        noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                               cross_attention_kwargs=cross_attention_kwargs, ).sample
                else:
                    added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
                                           cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)

                # x_t -> x_0
                alpha_prod_t = alphas_cumprod[t.long()]
                beta_prod_t = 1-alpha_prod_t
                latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i%callback_steps == 0:
                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
                        if latents is None:
                            return None

        latents = latents.to(dtype=self.vae.dtype)
        if not output_type == "latent":
            image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents
        has_nsfw_concept = None

        do_denormalize = [True]*image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

class HookPipe_Inpaint(StableDiffusionInpaintPipelineLegacy):
    @property
    def _execution_device(self) -> torch.device:
        return torch.device('cuda')

    @property
    def device(self) -> torch.device:
        return torch.device('cuda')

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        add_predicted_noise: Optional[bool] = False,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs
    ):
        # 1. Check inputs
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale>1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Preprocess image and mask
        if not isinstance(image, torch.FloatTensor):
            image = preprocess_image(image, batch_size).to(self._execution_device)

        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)

        # 6. Prepare latent variables
        # encode the init image into latents and scale the latents
        latents, init_latents_orig, noise = self.prepare_latents(
            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )

        # 7. Prepare mask latent
        mask = mask_image.to(device=self._execution_device, dtype=latents.dtype)
        mask = torch.cat([mask]*num_images_per_prompt)

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Denoising loop
        num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)

                # masking
                if add_predicted_noise:
                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_orig, noise_pred_uncond, torch.tensor([t])
                    )
                else:
                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

                # x_t-1 -> x_0
                alpha_prod_t = alphas_cumprod[t.long()]
                beta_prod_t = 1-alpha_prod_t
                latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
                # normalize latents_x0 to keep the contrast consistent with the original image
                latents_x0 = (latents_x0-latents_x0.mean())/latents_x0.std()*init_latents_orig.std()+init_latents_orig.mean()
                latents_x0 = (init_latents_orig*mask)+(latents_x0*(1-mask))

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                latents = (init_latents_proper*mask)+(latents*(1-mask))

                # call the callback, if provided
                if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i%callback_steps == 0:
                        latents = callback(i, t, num_inference_steps, latents_x0, latents)
                        if latents is None:
                            return None

        # use original latents corresponding to unmasked portions of the image
        latents = (init_latents_orig*mask)+(latents*(1-mask))
        latents = latents.to(dtype=self.vae.dtype)
        if not output_type == "latent":
            image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents
        has_nsfw_concept = None

        do_denormalize = [True]*image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

class HCPSDPipe(StableDiffusionImg2ImgPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)

        self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False, resample='nearest')

    @property
    def _execution_device(self) -> torch.device:
        return torch.device('cuda')

    @property
    def device(self) -> torch.device:
        return torch.device('cuda')

    def preprocess_images(self, image, mask, batch_size, num_images_per_prompt):
        if image is not None:
            if not isinstance(image, torch.FloatTensor):
                image = self.image_processor.preprocess(image, batch_size)
            if image.shape[0] == 1:
                image = repeat(image, 'n ... -> (n b) ...', b=batch_size)
        if mask is not None:
            mask = preprocess_mask(mask, batch_size, self.vae_scale_factor)
            mask = mask.to(device=self._execution_device, dtype=self.unet.dtype)
            mask = torch.cat([mask]*num_images_per_prompt)
        return image, mask

    def process_timesteps(self, num_inference_steps, strength, batch_size, num_images_per_prompt):
        self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self._execution_device)
        latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
        alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
        return timesteps, num_inference_steps, latent_timestep, alphas_cumprod

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        add_predicted_noise: Optional[bool] = False,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs
    ):
        # 1. Check inputs
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale>1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Preprocess image and mask
        image, mask = self.preprocess_images(image, mask_image, batch_size, num_images_per_prompt)

        # 5. set timesteps
        timesteps, num_inference_steps, latent_timestep, alphas_cumprod = self.process_timesteps(
            num_inference_steps, strength, batch_size, num_images_per_prompt)

        # 6. Prepare latent variables
        # encode the init image into latents and scale the latents
        latents, init_latents_orig, noise = self.prepare_latents(
            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )

        # 7. Prepare mask latent
        mask = mask_image.to(device=self._execution_device, dtype=latents.dtype)
        mask = torch.cat([mask]*num_images_per_prompt)

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Denoising loop
        num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)

                # masking
                if add_predicted_noise:
                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_orig, noise_pred_uncond, torch.tensor([t])
                    )
                else:
                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

                # x_t-1 -> x_0
                alpha_prod_t = alphas_cumprod[t.long()]
                beta_prod_t = 1-alpha_prod_t
                latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
                # normalize latents_x0 to keep the contrast consistent with the original image
                latents_x0 = (latents_x0-latents_x0.mean())/latents_x0.std()*init_latents_orig.std()+init_latents_orig.mean()
                latents_x0 = (init_latents_orig*mask)+(latents_x0*(1-mask))

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                latents = (init_latents_proper*mask)+(latents*(1-mask))

                # call the callback, if provided
                if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i%callback_steps == 0:
                        if callback(i, t, num_inference_steps, latents_x0):
                            return None

        # use original latents corresponding to unmasked portions of the image
        latents = (init_latents_orig*mask)+(latents*(1-mask))

        if not output_type == "latent":
            image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents
        has_nsfw_concept = None

        do_denormalize = [True]*image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
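Note: all four pipelines in the deleted file repeat the same two per-step computations, classifier-free guidance mixing and an approximate x_t -> x_0 conversion passed to the step callback. A minimal standalone sketch of that math in plain PyTorch (the function name, shapes, and values below are illustrative only, not hcpdiff API):

import torch

def cfg_and_x0(noise_pred: torch.Tensor, latents: torch.Tensor,
               alpha_prod_t: torch.Tensor, guidance_scale: float = 7.5):
    # classifier-free guidance: noise_pred stacks [uncond, text] along dim 0
    noise_uncond, noise_text = noise_pred.chunk(2)
    guided = noise_uncond + guidance_scale*(noise_text - noise_uncond)

    # x_t -> x_0 approximation used for the callback:
    # x_0 ~= (x_t - sqrt(1 - alpha_bar_t)*eps_hat) / sqrt(alpha_bar_t)
    beta_prod_t = 1 - alpha_prod_t
    latents_x0 = (latents - beta_prod_t**0.5*guided)/alpha_prod_t**0.5
    return guided, latents_x0

# smoke test with random tensors
noise = torch.randn(2, 4, 8, 8)   # [uncond, text]
x_t = torch.randn(1, 4, 8, 8)
eps_hat, x0_hat = cfg_and_x0(noise, x_t, torch.tensor(0.5))
print(eps_hat.shape, x0_hat.shape)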
The remaining files (WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) are unchanged between 2.3.1 and 2.4.