diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
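The largest additions are the new DeepFloyd IF pipelines under diffusers/pipelines/deepfloyd_if/. A minimal sketch of the stage-I pipeline, assuming access to the gated DeepFloyd/IF-I-XL-v1.0 weights:

import torch
from diffusers import IFPipeline

# Stage-I text-to-image pipeline added in this release; the repo id is the gated DeepFloyd checkpoint.
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # same accelerate-hook offloading added to other pipelines in this release

# IF encodes the prompt with T5 up front, then denoises in pixel space.
prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a corgi wearing a party hat")
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
image.save("if_stage_one.png")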
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image
 
 
-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
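With LoraLoaderMixin now mixed into the inpainting pipeline, LoRA and textual-inversion weights can be attached the same way as on the text-to-image pipeline. A sketch; the LoRA repo id below is a placeholder, not a shipped artifact:

import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# Both loaders come from the mixins referenced in the docstring above.
# "some-user/some-inpaint-lora" is an illustrative repo id.
pipe.load_lora_weights("some-user/some-inpaint-lora")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")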
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)
 
 
-def preprocess_image(image):
+def preprocess_image(image, batch_size):
     w, h = image.size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0
 
 
-def preprocess_mask(mask, scale_factor=8):
+def preprocess_mask(mask, batch_size, scale_factor=8):
     if not isinstance(mask, torch.FloatTensor):
         mask = mask.convert("L")
         w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
         mask = np.array(mask).astype(np.float32) / 255.0
         mask = np.tile(mask, (4, 1, 1))
-        mask = mask[None]
+        mask = np.vstack([mask[None]] * batch_size)
         mask = 1 - mask  # repaint white, keep black
         mask = torch.from_numpy(mask)
         return mask
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
         return mask
 
 
-class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipelineLegacy(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -511,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, image, timestep,
+    def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents
 
         # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] *
+        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents
 
         # add noise to latents using the timesteps
@@ -649,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
 
         # 4. Preprocess image and mask
         if not isinstance(image, torch.FloatTensor):
-            image = preprocess_image(image)
+            image = preprocess_image(image, batch_size)
 
-        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -661,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
         # 6. Prepare latent variables
         # encode the init image into latents and scale the latents
         latents, init_latents_orig, noise = self.prepare_latents(
-            image, latent_timestep,
+            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )
 
         # 7. Prepare mask latent
         mask = mask_image.to(device=self.device, dtype=latents.dtype)
-        mask = torch.cat([mask] *
+        mask = torch.cat([mask] * num_images_per_prompt)
 
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
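The legacy inpainting pipeline additionally picks up FromCkptMixin, so it can be built straight from a single checkpoint file. A sketch, assuming a local single-file checkpoint (the path is a placeholder):

from diffusers import StableDiffusionInpaintPipelineLegacy

# from_ckpt comes from the new FromCkptMixin; the path below is a placeholder for a local
# Stable Diffusion checkpoint in the original single-file format.
pipe = StableDiffusionInpaintPipelineLegacy.from_ckpt("./sd-v1-5-inpainting.ckpt")
pipe = pipe.to("cuda")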
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -20,7 +20,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
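For context, the InstructPix2Pix call path itself is unchanged; the mixin only adds optional load_lora_weights/save_lora_weights hooks on top of it. A usage sketch (the input file name is a placeholder):

import PIL.Image
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

image = PIL.Image.open("input.png").convert("RGB")  # placeholder input image
edited = pipe("make it look like winter", image=image, num_inference_steps=20).images[0]
edited.save("edited.png")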
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -36,6 +36,7 @@ from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -721,23 +722,31 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         )
 
         if isinstance(generator, list):
-
-
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-
-
-
-
-            if batch_size
-
-
-
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-
-
-            latents = init_latents
+            latents = torch.cat([latents], dim=0)
 
         return latents
 
@@ -759,23 +768,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         )
 
     def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
         reg_loss = 0.0
         for i in range(hidden_states.shape[0]):
-
-
-
-
-
-
-
-
-
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
         return reg_loss
 
     def kl_divergence(self, hidden_states):
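The rewritten auto_corr_loss drops the old batch-size-1 restriction: it iterates over every (batch, channel) slice and accumulates a rolled-autocorrelation penalty down an average-pooling pyramid. A standalone sketch of the same computation, for testing outside the pipeline:

from typing import Optional

import torch
import torch.nn.functional as F


def auto_corr_loss(hidden_states: torch.Tensor, generator: Optional[torch.Generator] = None):
    # hidden_states: (batch, channels, height, width)
    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):
        for j in range(hidden_states.shape[1]):
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                # penalize correlation between the map and randomly rolled copies of itself,
                # along both spatial axes
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)  # next level of the pooling pyramid
    return reg_loss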
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -13,18 +13,20 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import numpy as np
 import PIL
 import torch
-
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -76,6 +78,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+    _optional_components = ["watermarker", "safety_checker", "feature_extractor"]
 
     def __init__(
         self,
@@ -85,12 +88,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         unet: UNet2DConditionModel,
         low_res_scheduler: DDPMScheduler,
         scheduler: KarrasDiffusionSchedulers,
+        safety_checker: Optional[Any] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        watermarker: Optional[Any] = None,
         max_noise_level: int = 350,
     ):
         super().__init__()
 
-        if hasattr(
-
+        if hasattr(
+            vae, "config"
+        ):  # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
             is_vae_scaling_factor_set_to_0_08333 = (
                 hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
             )
@@ -113,6 +120,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
             unet=unet,
             low_res_scheduler=low_res_scheduler,
             scheduler=scheduler,
+            safety_checker=safety_checker,
+            watermarker=watermarker,
+            feature_extractor=feature_extractor,
         )
         self.register_to_config(max_noise_level=max_noise_level)
 
@@ -129,10 +139,36 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
 
         device = torch.device(f"cuda:{gpu_id}")
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            if cpu_offloaded_model is not None:
+                _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
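This gives the upscaler the same model-level offloading already available on the main text-to-image pipeline. A usage sketch; the input image path is a placeholder:

import PIL.Image
import torch
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # keeps only one sub-model on the GPU at a time

low_res = PIL.Image.open("low_res.png").convert("RGB")  # placeholder low-resolution input
upscaled = pipe(prompt="a white cat", image=low_res).images[0]
upscaled.save("upscaled.png")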
@@ -152,6 +188,23 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, nsfw_detected, watermark_detected = self.safety_checker(
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+            )
+        else:
+            nsfw_detected = None
+            watermark_detected = None
+
+        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+            self.unet_offload_hook.offload()
+
+        return image, nsfw_detected, watermark_detected
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
@@ -645,13 +698,43 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-
+
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
 
         # 11. Convert to PIL
         if output_type == "pil":
+            image = self.decode_latents(latents)
+
+            image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
             image = self.numpy_to_pil(image)
 
+            # 11. Apply watermark
+            if self.watermarker is not None:
+                image = self.watermarker.apply_watermark(image)
+        elif output_type == "pt":
+            latents = 1 / self.vae.config.scaling_factor * latents
+            image = self.vae.decode(latents).sample
+            has_nsfw_concept = None
+        else:
+            image = self.decode_latents(latents)
+            has_nsfw_concept = None
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         if not return_dict:
-            return (image,)
+            return (image, has_nsfw_concept)
 
-        return
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
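Because the upscaler now returns a StableDiffusionPipelineOutput (or an (image, has_nsfw_concept) tuple when return_dict=False), callers can read the safety-checker verdict. A short sketch reusing pipe and low_res from the sketch above:

result = pipe(prompt="a white cat", image=low_res)  # StableDiffusionPipelineOutput
print(result.nsfw_content_detected)  # None unless a safety_checker was passed to the pipeline
image, has_nsfw_concept = pipe(prompt="a white cat", image=low_res, return_dict=False)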
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -15,7 +15,7 @@ from ...models.attention_processor import (
     AttnProcessor,
 )
 from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
 from ...utils import logging
@@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -246,12 +251,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -261,6 +268,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
 
@@ -311,7 +319,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -319,7 +327,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
 
             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -345,7 +353,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -370,6 +378,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         else:
             self.class_embedding = None
 
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
@@ -534,7 +554,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -745,7 +776,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         t_emb = self.time_proj(timesteps)
 
-        #
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +790,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
@@ -766,6 +801,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             else:
                 emb = emb + class_emb
 
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
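The same addition_embed_type / time_embedding_dim / addition_embed_type_num_heads arguments land in UNet2DConditionModel (unet_2d_condition.py carries the matching +45 -6 change). A configuration sketch with deliberately tiny, illustrative dimensions:

import torch
from diffusers import UNet2DConditionModel

# Tiny dimensions purely for illustration; real checkpoints are much larger.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
    addition_embed_type="text",       # new: TextTimeEmbedding summed with the time embedding
    addition_embed_type_num_heads=4,  # new
    time_embedding_dim=256,           # new: overrides the default block_out_channels[0] * 4
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 64)
out = unet(sample, timestep=torch.tensor([10]), encoder_hidden_states=encoder_hidden_states).sample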