diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
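With `LoraLoaderMixin` now in the inpaint pipeline's bases, LoRA weights can be loaded onto (and saved from) the pipeline directly. A minimal usage sketch; the checkpoint identifiers below are placeholders, not files shipped with this package:

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

# Any Stable Diffusion inpainting checkpoint works here; this repo id is only an example.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# New in 0.16: attach LoRA attention weights to the pipeline's UNet / text encoder.
# "path/to/lora" stands in for a local folder or Hub repo containing LoRA weights.
pipe.load_lora_weights("path/to/lora")

# The counterpart, StableDiffusionInpaintPipeline.save_lora_weights(...), is inherited
# from LoraLoaderMixin as well.
```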
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)


-def preprocess_image(image):
+def preprocess_image(image, batch_size):
     w, h = image.size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0


-def preprocess_mask(mask, scale_factor=8):
+def preprocess_mask(mask, batch_size, scale_factor=8):
     if not isinstance(mask, torch.FloatTensor):
         mask = mask.convert("L")
         w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
         mask = np.array(mask).astype(np.float32) / 255.0
         mask = np.tile(mask, (4, 1, 1))
-        mask = mask[None]
+        mask = np.vstack([mask[None]] * batch_size)
         mask = 1 - mask  # repaint white, keep black
         mask = torch.from_numpy(mask)
         return mask
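The practical effect of threading `batch_size` through these helpers is that the init image and mask come out of preprocessing already stacked along the batch dimension, instead of being tiled later in `prepare_latents`. A standalone shape sketch of the same `np.vstack` pattern (a blank dummy image stands in for a real init image):

```python
import numpy as np
import torch
from PIL import Image

batch_size = 2
image = Image.new("RGB", (512, 512))  # stand-in for a real init image

arr = np.array(image).astype(np.float32) / 255.0
arr = np.vstack([arr[None].transpose(0, 3, 1, 2)] * batch_size)  # one copy per prompt
tensor = 2.0 * torch.from_numpy(arr) - 1.0

print(tensor.shape)  # torch.Size([2, 3, 512, 512])
```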
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
         return mask


-class StableDiffusionInpaintPipelineLegacy(
+class StableDiffusionInpaintPipelineLegacy(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
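Because `FromCkptMixin` is now mixed in, the legacy inpaint pipeline can also be instantiated from a single original-format checkpoint file. A sketch, assuming a local `.ckpt`/`.safetensors` path (the filename below is a placeholder):

```python
from diffusers import StableDiffusionInpaintPipelineLegacy

# "./v1-5-pruned-emaonly.ckpt" is a placeholder path to any single-file SD checkpoint on disk.
pipe = StableDiffusionInpaintPipelineLegacy.from_ckpt("./v1-5-pruned-emaonly.ckpt")
```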
@@ -511,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

         return timesteps, num_inference_steps - t_start

-    def prepare_latents(self, image, timestep,
+    def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents

         # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] *
+        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents

         # add noise to latents using the timesteps
@@ -649,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

         # 4. Preprocess image and mask
         if not isinstance(image, torch.FloatTensor):
-            image = preprocess_image(image)
+            image = preprocess_image(image, batch_size)

-        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -661,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
         # 6. Prepare latent variables
         # encode the init image into latents and scale the latents
         latents, init_latents_orig, noise = self.prepare_latents(
-            image, latent_timestep,
+            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )

         # 7. Prepare mask latent
         mask = mask_image.to(device=self.device, dtype=latents.dtype)
-        mask = torch.cat([mask] *
+        mask = torch.cat([mask] * num_images_per_prompt)

         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
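After this change the image and mask are already duplicated to `batch_size` during preprocessing, so the latents and mask only need one further repeat by `num_images_per_prompt`. A shape-level sketch of that expansion (dummy tensors, no pipeline involved):

```python
import torch

batch_size = 2               # prompts; image/mask are already stacked to this size
num_images_per_prompt = 3

init_latents = torch.randn(batch_size, 4, 64, 64)
mask = torch.rand(batch_size, 4, 64, 64)

init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
mask = torch.cat([mask] * num_images_per_prompt)

print(init_latents.shape, mask.shape)  # both torch.Size([6, 4, 64, 64])
```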
@@ -20,7 +20,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
     return image


-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -36,6 +36,7 @@ from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -721,23 +722,31 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         )

         if isinstance(generator, list):
-
-
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-
-
-
-
-            if batch_size
-
-
-
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-
-
-            latents = init_latents
+            latents = torch.cat([latents], dim=0)

         return latents

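When fewer initial images than prompts are passed, the encoded latents are now repeated to match (with a deprecation warning) as long as the prompt count is a whole multiple of the image count; otherwise a `ValueError` is raised. A standalone sketch of just that duplication arithmetic:

```python
import torch

batch_size = 4                        # number of text prompts
latents = torch.randn(2, 4, 64, 64)   # latents encoded from 2 initial images

if batch_size != latents.shape[0]:
    if batch_size % latents.shape[0] == 0:
        additional_latents_per_image = batch_size // latents.shape[0]
        latents = torch.cat([latents] * additional_latents_per_image, dim=0)
    else:
        raise ValueError("Cannot duplicate images to match the number of prompts.")

print(latents.shape)  # torch.Size([4, 4, 64, 64])
```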
@@ -759,23 +768,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         )

     def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
         reg_loss = 0.0
         for i in range(hidden_states.shape[0]):
-
-
-
-
-
-
-
-
-
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
         return reg_loss

     def kl_divergence(self, hidden_states):
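The reworked `auto_corr_loss` drops the old batch-size-1 restriction: it loops over every (batch, channel) slice and penalizes the auto-correlation of the noise at multiple scales via random rolls and repeated 2x average pooling. A self-contained sketch mirroring the 0.16.1 loop on a dummy tensor:

```python
import torch
import torch.nn.functional as F

def auto_corr_loss(hidden_states, generator=None):
    # Multi-scale auto-correlation penalty over each (batch, channel) slice.
    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):        # batch
        for j in range(hidden_states.shape[1]):    # channels
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)  # move to the next (coarser) scale
    return reg_loss

print(auto_corr_loss(torch.randn(2, 4, 64, 64)))
```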
@@ -13,18 +13,20 @@
 # limitations under the License.

 import inspect
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union

 import numpy as np
 import PIL
 import torch
-
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -76,6 +78,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+    _optional_components = ["watermarker", "safety_checker", "feature_extractor"]

     def __init__(
         self,
@@ -85,12 +88,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         unet: UNet2DConditionModel,
         low_res_scheduler: DDPMScheduler,
         scheduler: KarrasDiffusionSchedulers,
+        safety_checker: Optional[Any] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        watermarker: Optional[Any] = None,
         max_noise_level: int = 350,
     ):
         super().__init__()

-        if hasattr(
-
+        if hasattr(
+            vae, "config"
+        ):  # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
             is_vae_scaling_factor_set_to_0_08333 = (
                 hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
             )
@@ -113,6 +120,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
             unet=unet,
             low_res_scheduler=low_res_scheduler,
             scheduler=scheduler,
+            safety_checker=safety_checker,
+            watermarker=watermarker,
+            feature_extractor=feature_extractor,
         )
         self.register_to_config(max_noise_level=max_noise_level)

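Since `watermarker`, `safety_checker`, and `feature_extractor` are registered as `_optional_components` and default to `None`, the upscaler keeps loading from repos that do not ship them. A loading sketch; the repo id is the commonly used upscaler checkpoint and is not bundled with this wheel:

```python
import torch
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)

# Components missing from the repo simply stay None instead of raising.
print(pipe.safety_checker, pipe.watermarker)
```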
@@ -129,10 +139,36 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi

         device = torch.device(f"cuda:{gpu_id}")

-        for cpu_offloaded_model in [self.unet, self.text_encoder]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            if cpu_offloaded_model is not None:
+                _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
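`enable_model_cpu_offload` is new on this pipeline and requires accelerate 0.17 or newer; it keeps only the sub-model that is currently running on the GPU via accelerate hooks. A usage sketch, continuing from the loading snippet above (`pipe` is assumed to be an already loaded `StableDiffusionUpscalePipeline`):

```python
# Instead of pipe.to("cuda"): each sub-model (text encoder, UNet, VAE) is moved to the
# GPU just-in-time for its forward pass and offloaded back to CPU afterwards.
pipe.enable_model_cpu_offload()

# Sequential offloading remains available and saves more memory at a larger speed cost:
# pipe.enable_sequential_cpu_offload()
```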
@@ -152,6 +188,23 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
                 return torch.device(module._hf_hook.execution_device)
         return self.device

+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, nsfw_detected, watermark_detected = self.safety_checker(
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+            )
+        else:
+            nsfw_detected = None
+            watermark_detected = None
+
+        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+            self.unet_offload_hook.offload()
+
+        return image, nsfw_detected, watermark_detected
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
@@ -645,13 +698,43 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-
+
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()

         # 11. Convert to PIL
         if output_type == "pil":
+            image = self.decode_latents(latents)
+
+            image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
             image = self.numpy_to_pil(image)

+            # 11. Apply watermark
+            if self.watermarker is not None:
+                image = self.watermarker.apply_watermark(image)
+        elif output_type == "pt":
+            latents = 1 / self.vae.config.scaling_factor * latents
+            image = self.vae.decode(latents).sample
+            has_nsfw_concept = None
+        else:
+            image = self.decode_latents(latents)
+            has_nsfw_concept = None
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         if not return_dict:
-            return (image,)
+            return (image, has_nsfw_concept)

-        return
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
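With the safety checker and watermarker wired into post-processing, the upscaler now returns a `StableDiffusionPipelineOutput`; with `return_dict=False` it returns an `(image, nsfw_content_detected)` tuple, and `output_type="pt"` skips the NSFW check entirely. A consumption sketch, assuming `pipe` is a loaded `StableDiffusionUpscalePipeline` and `low_res_img` a PIL image (both placeholders here):

```python
result = pipe(prompt="a white cat", image=low_res_img)
upscaled = result.images[0]
print(result.nsfw_content_detected)   # None, or a list of booleans when a safety checker is loaded

# The same call with return_dict=False yields a plain tuple instead of the output class.
image, has_nsfw_concept = pipe(prompt="a white cat", image=low_res_img, return_dict=False)
```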
@@ -15,7 +15,7 @@ from ...models.attention_processor import (
     AttnProcessor,
 )
 from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
 from ...utils import logging
@@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
         time_embedding_type (`str`, *optional*, default to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
         time_embedding_act_fn (`str`, *optional*, default to `None`):
             Optional activation function to use on the time embeddings only one time before they as passed to the rest
             of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -246,12 +251,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
         resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -261,6 +268,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()

@@ -311,7 +319,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         # time
         if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
                 raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
             self.time_proj = GaussianFourierProjection(
@@ -319,7 +327,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             )
             timestep_input_dim = time_embed_dim
         elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

             self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
             timestep_input_dim = block_out_channels[0]
@@ -345,7 +353,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -370,6 +378,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         else:
             self.class_embedding = None

+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
         if time_embedding_act_fn is None:
             self.time_embed_act = None
         elif time_embedding_act_fn == "swish":
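`addition_embed_type="text"` wires a `TextTimeEmbedding`, pooled over `encoder_hidden_states`, into the time embedding. The same flags exist on `UNet2DConditionModel` in this release (its diff mirrors this file); a small-config sketch with illustrative test-style shapes, not a real checkpoint config:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    addition_embed_type="text",        # new in 0.16: extra text-pooled embedding
    addition_embed_type_num_heads=8,   # must divide the embedding width used by the pooling
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```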
@@ -534,7 +554,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -745,7 +776,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         t_emb = self.time_proj(timesteps)

-        #
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +790,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)

+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

             if self.config.class_embeddings_concat:
@@ -766,6 +801,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             else:
                 emb = emb + class_emb

+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)

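At forward time the new embedding is simply pooled from the prompt hidden states and added to the timestep embedding. A minimal sketch using the `TextTimeEmbedding` layer directly; the dimensions (encoder width 32, time-embedding width 128, 8 heads) are illustrative values, not taken from any shipped config:

```python
import torch
from diffusers.models.embeddings import TextTimeEmbedding

# Pools the prompt hidden states into one vector and projects it to the
# time-embedding width, so it can be added onto the timestep embedding.
add_embedding = TextTimeEmbedding(32, 128, num_heads=8)

encoder_hidden_states = torch.randn(2, 77, 32)  # (batch, tokens, encoder dim)
time_emb = torch.randn(2, 128)                  # stand-in timestep embedding

aug_emb = add_embedding(encoder_hidden_states)  # (2, 128)
emb = time_emb + aug_emb
print(emb.shape)  # torch.Size([2, 128])
```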