hcpdiff 2.3.1__py3-none-any.whl → 2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/format/__init__.py +2 -2
  3. hcpdiff/ckpt_manager/format/diffusers.py +19 -4
  4. hcpdiff/ckpt_manager/format/emb.py +8 -3
  5. hcpdiff/ckpt_manager/format/lora_webui.py +1 -1
  6. hcpdiff/ckpt_manager/format/sd_single.py +28 -5
  7. hcpdiff/data/cache/vae.py +10 -2
  8. hcpdiff/data/handler/text.py +15 -14
  9. hcpdiff/diffusion/sampler/__init__.py +2 -1
  10. hcpdiff/diffusion/sampler/base.py +17 -6
  11. hcpdiff/diffusion/sampler/diffusers.py +4 -3
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +5 -14
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +7 -6
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +4 -4
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +3 -3
  16. hcpdiff/diffusion/sampler/timer/__init__.py +2 -0
  17. hcpdiff/diffusion/sampler/timer/base.py +26 -0
  18. hcpdiff/diffusion/sampler/timer/shift.py +49 -0
  19. hcpdiff/easy/__init__.py +2 -1
  20. hcpdiff/easy/cfg/sd15_train.py +1 -3
  21. hcpdiff/easy/model/__init__.py +1 -1
  22. hcpdiff/easy/model/loader.py +33 -11
  23. hcpdiff/easy/sampler.py +8 -1
  24. hcpdiff/loss/__init__.py +4 -3
  25. hcpdiff/loss/charbonnier.py +17 -0
  26. hcpdiff/loss/vlb.py +2 -2
  27. hcpdiff/loss/weighting.py +29 -11
  28. hcpdiff/models/__init__.py +1 -1
  29. hcpdiff/models/cfg_context.py +5 -3
  30. hcpdiff/models/compose/__init__.py +2 -1
  31. hcpdiff/models/compose/compose_hook.py +69 -67
  32. hcpdiff/models/compose/compose_textencoder.py +59 -45
  33. hcpdiff/models/compose/compose_tokenizer.py +48 -11
  34. hcpdiff/models/compose/flux.py +75 -0
  35. hcpdiff/models/compose/sdxl.py +86 -0
  36. hcpdiff/models/text_emb_ex.py +13 -9
  37. hcpdiff/models/textencoder_ex.py +8 -38
  38. hcpdiff/models/wrapper/__init__.py +2 -1
  39. hcpdiff/models/wrapper/flux.py +75 -0
  40. hcpdiff/models/wrapper/pixart.py +13 -1
  41. hcpdiff/models/wrapper/sd.py +17 -8
  42. hcpdiff/parser/embpt.py +7 -7
  43. hcpdiff/utils/net_utils.py +22 -12
  44. hcpdiff/workflow/__init__.py +1 -1
  45. hcpdiff/workflow/diffusion.py +145 -18
  46. hcpdiff/workflow/text.py +49 -18
  47. hcpdiff/workflow/vae.py +10 -2
  48. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/METADATA +1 -1
  49. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/RECORD +53 -49
  50. hcpdiff/models/compose/sdxl_composer.py +0 -39
  51. hcpdiff/utils/inpaint_pipe.py +0 -790
  52. hcpdiff/utils/pipe_hook.py +0 -656
  53. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/WHEEL +0 -0
  54. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/entry_points.txt +0 -0
  55. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/licenses/LICENSE +0 -0
  56. {hcpdiff-2.3.1.dist-info → hcpdiff-2.4.dist-info}/top_level.txt +0 -0
hcpdiff/utils/pipe_hook.py (removed)
@@ -1,656 +0,0 @@
- from typing import Union, List, Optional, Callable, Dict, Any
-
- import PIL
- import torch
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, PixArtTransformer2DModel
- from diffusers.image_processor import VaeImageProcessor
- from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
- from .inpaint_pipe import preprocess_mask, preprocess_image, StableDiffusionInpaintPipelineLegacy
- from einops import repeat
-
- class HookPipe_T2I(StableDiffusionPipeline):
-     @property
-     def _execution_device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @property
-     def device(self) -> torch.device:
-         return torch.device('cuda')
-
-     def proc_prompt(self, device, num_inference_steps, prompt_embeds = None, negative_prompt_embeds = None) -> List[torch.Tensor]:
-         if not isinstance(prompt_embeds, list): # to emb for each step
-             prompt_embeds = [prompt_embeds]*num_inference_steps
-         if not isinstance(negative_prompt_embeds, list): # to emb for each step
-             negative_prompt_embeds = [negative_prompt_embeds]*num_inference_steps
-
-         prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in prompt_embeds]
-         negative_prompt_embeds = [p.to(dtype=self.text_encoder.dtype, device=device) for p in negative_prompt_embeds]
-
-         prompt_embeds = [torch.cat([emb_neg, emb_pos]) for emb_pos, emb_neg in zip(prompt_embeds, negative_prompt_embeds)]
-         return prompt_embeds # List[emb_step_i]*num_inference_steps
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 50,
-         guidance_scale: float = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         eta: float = 0.0,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-         pooled_output: Optional[torch.FloatTensor] = None,
-         crop_coord: Optional[torch.FloatTensor] = None,
-         **kwargs
-     ):
-         # 0. Default height and width to unet
-         height = height or self.unet.config.sample_size*self.vae_scale_factor
-         width = width or self.unet.config.sample_size*self.vae_scale_factor
-
-         # 1. Check inputs. Raise error if not correct
-         # self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         elif isinstance(prompt_embeds, list):
-             batch_size = prompt_embeds[0].shape[0]
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale>1.0
-
-         # 3. Encode input prompt
-         prompt_embeds = self.proc_prompt(device, num_inference_steps,
-                                          prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds)
-
-         # 4. Prepare timesteps
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps = self.scheduler.timesteps
-         alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-
-         # 5. Prepare latent variables
-         num_channels_latents = self.unet.config.in_channels
-         latents = self.prepare_latents(
-             batch_size,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds[0].dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # SDXL inputs
-         if pooled_output is not None:
-             if crop_coord is None:
-                 crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
-             else:
-                 crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
-             crop_info = crop_info.to(device).repeat(batch_size, 1)
-             pooled_output = pooled_output.to(device)
-
-             if do_classifier_free_guidance:
-                 crop_info = torch.cat([crop_info, crop_info], dim=0)
-
-         # 7. Denoising loop
-         num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 if pooled_output is None:
-                     if isinstance(self.unet, PixArtTransformer2DModel):
-                         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-                         noise_pred = self.unet(latent_model_input, timestep=t.repeat(latent_model_input.shape[0]), encoder_hidden_states=prompt_embeds[i],
-                                                encoder_attention_mask=encoder_attention_mask,
-                                                cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
-                     else:
-                         noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
-                                                encoder_attention_mask=encoder_attention_mask,
-                                                cross_attention_kwargs=cross_attention_kwargs).sample
-                 else:
-                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
-                     # predict the noise residual
-                     noise_pred = self.unet(latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds[i],
-                                            encoder_attention_mask=encoder_attention_mask,
-                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
-
-                 # learned sigma
-                 if self.unet.config.out_channels // 2 == num_channels_latents:
-                     noise_pred = noise_pred.chunk(2, dim=1)[0]
-
-                 # x_t -> x_0
-                 alpha_prod_t = alphas_cumprod[t.long()]
-                 beta_prod_t = 1-alpha_prod_t
-                 latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 sc_out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
-                 latents = sc_out.prev_sample
-
-                 # call the callback, if provided
-                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
-                     progress_bar.update()
-                     if callback is not None and i%callback_steps == 0:
-                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
-                         if latents is None:
-                             return None
-
-         latents = latents.to(dtype=self.vae.dtype)
-         if not output_type == "latent":
-             image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         else:
-             image = latents
-
-         do_denormalize = [True]*image.shape[0]
-         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, None)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
-
- class HookPipe_I2I(StableDiffusionImg2ImgPipeline):
-     @property
-     def _execution_device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @property
-     def device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         strength: float = 0.8,
-         num_inference_steps: Optional[int] = 50,
-         guidance_scale: Optional[float] = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         eta: Optional[float] = 0.0,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-         pooled_output: Optional[torch.FloatTensor] = None,
-         crop_coord: Optional[torch.FloatTensor] = None,
-         **kwargs
-     ):
-         # 1. Check inputs. Raise error if not correct
-         self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale>1.0
-
-         # 3. Encode input prompt
-         prompt_embeds = self._encode_prompt(
-             prompt,
-             device,
-             num_images_per_prompt,
-             do_classifier_free_guidance,
-             negative_prompt,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-         )
-
-         # 4. Preprocess image
-         image = self.image_processor.preprocess(image)
-         image = repeat(image, 'n ... -> (n b) ...', b=batch_size)
-         height, width = image.shape[2:]
-
-         # 5. set timesteps
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
-         latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
-         alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-
-         # 6. Prepare latent variables
-         latents = self.prepare_latents(
-             image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
-         ).to(self.unet.dtype)
-
-         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # SDXL inputs
-         if pooled_output is not None:
-             if crop_coord is None:
-                 crop_info = torch.tensor([height, width, 0, 0, height, width], dtype=torch.float)
-             else:
-                 crop_info = torch.tensor([height, width, *crop_coord], dtype=torch.float)
-             crop_info = crop_info.to(device).repeat(batch_size*num_images_per_prompt, 1)
-             pooled_output = pooled_output.to(device)
-
-             if do_classifier_free_guidance:
-                 crop_info = torch.cat([crop_info, crop_info], dim=0)
-
-         # 8. Denoising loop
-         num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 # predict the noise residual
-                 if pooled_output is None:
-                     if isinstance(self.unet, PixArtTransformer2DModel):
-                         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-                         noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                                cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
-                     else:
-                         noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                                cross_attention_kwargs=cross_attention_kwargs, ).sample
-                 else:
-                     added_cond_kwargs = {"text_embeds":pooled_output, "time_ids":crop_info}
-                     # predict the noise residual
-                     noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask,
-                                            cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
-
-                 # x_t -> x_0
-                 alpha_prod_t = alphas_cumprod[t.long()]
-                 beta_prod_t = 1-alpha_prod_t
-                 latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 # call the callback, if provided
-                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
-                     progress_bar.update()
-                     if callback is not None and i%callback_steps == 0:
-                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
-                         if latents is None:
-                             return None
-
-         latents = latents.to(dtype=self.vae.dtype)
-         if not output_type == "latent":
-             image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         else:
-             image = latents
-         has_nsfw_concept = None
-
-         do_denormalize = [True]*image.shape[0]
-         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, has_nsfw_concept)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-
- class HookPipe_Inpaint(StableDiffusionInpaintPipelineLegacy):
-     @property
-     def _execution_device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @property
-     def device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         strength: float = 0.8,
-         num_inference_steps: Optional[int] = 50,
-         guidance_scale: Optional[float] = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         add_predicted_noise: Optional[bool] = False,
-         eta: Optional[float] = 0.0,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         **kwargs
-     ):
-         # 1. Check inputs
-         self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale>1.0
-
-         # 3. Encode input prompt
-         prompt_embeds = self._encode_prompt(
-             prompt,
-             device,
-             num_images_per_prompt,
-             do_classifier_free_guidance,
-             negative_prompt,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-         )
-
-         # 4. Preprocess image and mask
-         if not isinstance(image, torch.FloatTensor):
-             image = preprocess_image(image, batch_size).to(self._execution_device)
-
-         mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
-
-         # 5. set timesteps
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
-         latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
-         alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-
-         # 6. Prepare latent variables
-         # encode the init image into latents and scale the latents
-         latents, init_latents_orig, noise = self.prepare_latents(
-             image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
-         )
-
-         # 7. Prepare mask latent
-         mask = mask_image.to(device=self._execution_device, dtype=latents.dtype)
-         mask = torch.cat([mask]*num_images_per_prompt)
-
-         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # 9. Denoising loop
-         num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 # predict the noise residual
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds, encoder_attention_mask=encoder_attention_mask).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
-
-                 # masking
-                 if add_predicted_noise:
-                     init_latents_proper = self.scheduler.add_noise(
-                         init_latents_orig, noise_pred_uncond, torch.tensor([t])
-                     )
-                 else:
-                     init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-                 # x_t-1 -> x_0
-                 alpha_prod_t = alphas_cumprod[t.long()]
-                 beta_prod_t = 1-alpha_prod_t
-                 latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
-                 # normalize latents_x0 to keep the contrast consistent with the original image
-                 latents_x0 = (latents_x0-latents_x0.mean())/latents_x0.std()*init_latents_orig.std()+init_latents_orig.mean()
-                 latents_x0 = (init_latents_orig*mask)+(latents_x0*(1-mask))
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 latents = (init_latents_proper*mask)+(latents*(1-mask))
-
-                 # call the callback, if provided
-                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
-                     progress_bar.update()
-                     if callback is not None and i%callback_steps == 0:
-                         latents = callback(i, t, num_inference_steps, latents_x0, latents)
-                         if latents is None:
-                             return None
-
-         # use original latents corresponding to unmasked portions of the image
-         latents = (init_latents_orig*mask)+(latents*(1-mask))
-         latents = latents.to(dtype=self.vae.dtype)
-         if not output_type == "latent":
-             image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         else:
-             image = latents
-         has_nsfw_concept = None
-
-         do_denormalize = [True]*image.shape[0]
-         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, has_nsfw_concept)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-
- class HCPSDPipe(StableDiffusionImg2ImgPipeline):
-     def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor):
-         super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
-
-         self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False, resample='nearest')
-
-     @property
-     def _execution_device(self) -> torch.device:
-         return torch.device('cuda')
-
-     @property
-     def device(self) -> torch.device:
-         return torch.device('cuda')
-
-     def preprocess_images(self, image, mask, batch_size, num_images_per_prompt):
-         if image is not None:
-             if not isinstance(image, torch.FloatTensor):
-                 image = self.image_processor.preprocess(image, batch_size)
-             if image.shape[0] == 1:
-                 image = repeat(image, 'n ... -> (n b) ...', b=batch_size)
-         if mask is not None:
-             mask = preprocess_mask(mask, batch_size, self.vae_scale_factor)
-             mask = mask.to(device=self._execution_device, dtype=self.unet.dtype)
-             mask = torch.cat([mask]*num_images_per_prompt)
-         return image, mask
-
-     def process_timesteps(self, num_inference_steps, strength, batch_size, num_images_per_prompt):
-         self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
-         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self._execution_device)
-         latent_timestep = timesteps[:1].repeat(batch_size*num_images_per_prompt)
-         alphas_cumprod = self.scheduler.alphas_cumprod.to(timesteps.device)
-         return timesteps, num_inference_steps, latent_timestep, alphas_cumprod
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]] = None,
-         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         strength: float = 0.8,
-         num_inference_steps: Optional[int] = 50,
-         guidance_scale: Optional[float] = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         add_predicted_noise: Optional[bool] = False,
-         eta: Optional[float] = 0.0,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[int, int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         **kwargs
-     ):
-         # 1. Check inputs
-         self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
-
-         # 2. Define call parameters
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale>1.0
-
-         # 3. Encode input prompt
-         prompt_embeds = self._encode_prompt(
-             prompt,
-             device,
-             num_images_per_prompt,
-             do_classifier_free_guidance,
-             negative_prompt,
-             prompt_embeds=prompt_embeds,
-             negative_prompt_embeds=negative_prompt_embeds,
-         )
-
-         # 4. Preprocess image and mask
-         image, mask = self.preprocess_images(image, mask_image, batch_size, num_images_per_prompt)
-
-         # 5. set timesteps
-         timesteps, num_inference_steps, latent_timestep, alphas_cumprod = self.process_timesteps(
-             num_inference_steps, strength, batch_size, num_images_per_prompt)
-
-         # 6. Prepare latent variables
-         # encode the init image into latents and scale the latents
-         latents, init_latents_orig, noise = self.prepare_latents(
-             image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
-         )
-
-         # 7. Prepare mask latent
-         mask = mask_image.to(device=self._execution_device, dtype=latents.dtype)
-         mask = torch.cat([mask]*num_images_per_prompt)
-
-         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # 9. Denoising loop
-         num_warmup_steps = len(timesteps)-num_inference_steps*self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat([latents]*2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 # predict the noise residual
-                 noise_pred = self.unet(latent_model_input, t, prompt_embeds).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond+guidance_scale*(noise_pred_text-noise_pred_uncond)
-
-                 # masking
-                 if add_predicted_noise:
-                     init_latents_proper = self.scheduler.add_noise(
-                         init_latents_orig, noise_pred_uncond, torch.tensor([t])
-                     )
-                 else:
-                     init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-                 # x_t-1 -> x_0
-                 alpha_prod_t = alphas_cumprod[t.long()]
-                 beta_prod_t = 1-alpha_prod_t
-                 latents_x0 = (latents-beta_prod_t**(0.5)*noise_pred)/alpha_prod_t**(0.5) # approximate x_0
-                 # normalize latents_x0 to keep the contrast consistent with the original image
-                 latents_x0 = (latents_x0-latents_x0.mean())/latents_x0.std()*init_latents_orig.std()+init_latents_orig.mean()
-                 latents_x0 = (init_latents_orig*mask)+(latents_x0*(1-mask))
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 latents = (init_latents_proper*mask)+(latents*(1-mask))
-
-                 # call the callback, if provided
-                 if i == len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order == 0):
-                     progress_bar.update()
-                     if callback is not None and i%callback_steps == 0:
-                         if callback(i, t, num_inference_steps, latents_x0):
-                             return None
-
-         # use original latents corresponding to unmasked portions of the image
-         latents = (init_latents_orig*mask)+(latents*(1-mask))
-
-         if not output_type == "latent":
-             image = self.vae.decode(latents/self.vae.config.scaling_factor, return_dict=False)[0]
-         else:
-             image = latents
-         has_nsfw_concept = None
-
-         do_denormalize = [True]*image.shape[0]
-         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, has_nsfw_concept)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)