diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -477,8 +477,9 @@ class UnCLIPPipeline(DiffusionPipeline):
|
|
477
477
|
image = super_res_latents
|
478
478
|
# done super res
|
479
479
|
|
480
|
-
|
480
|
+
self.maybe_free_model_hooks()
|
481
481
|
|
482
|
+
# post processing
|
482
483
|
image = image * 0.5 + 0.5
|
483
484
|
image = image.clamp(0, 1)
|
484
485
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
@@ -19,8 +19,8 @@ import torch
|
|
19
19
|
import torch.nn as nn
|
20
20
|
|
21
21
|
from ...configuration_utils import ConfigMixin, register_to_config
|
22
|
+
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
|
22
23
|
from ...models.modeling_utils import ModelMixin
|
23
|
-
from ...models.vae import DecoderOutput, VectorQuantizer
|
24
24
|
from ...models.vq_model import VQEncoderOutput
|
25
25
|
from ...utils.accelerate_utils import apply_forward_hook
|
26
26
|
|
@@ -17,6 +17,8 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
|
19
19
|
from ...models.attention_processor import Attention
|
20
|
+
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
21
|
+
from ...utils import USE_PEFT_BACKEND
|
20
22
|
|
21
23
|
|
22
24
|
class WuerstchenLayerNorm(nn.LayerNorm):
|
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
|
|
32
34
|
class TimestepBlock(nn.Module):
|
33
35
|
def __init__(self, c, c_timestep):
|
34
36
|
super().__init__()
|
35
|
-
self.mapper = nn.Linear(c_timestep, c * 2)
|
37
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
38
|
+
self.mapper = linear_cls(c_timestep, c * 2)
|
36
39
|
|
37
40
|
def forward(self, x, t):
|
38
41
|
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
|
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
|
|
42
45
|
class ResBlock(nn.Module):
|
43
46
|
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
44
47
|
super().__init__()
|
45
|
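-
self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)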
|
48
|
+
|
49
|
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
50
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
51
|
+
|
52
|
+
self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
46
53
|
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
47
54
|
self.channelwise = nn.Sequential(
|
48
|
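-
nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)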
|
55
|
+
linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
|
49
56
|
)
|
50
57
|
|
51
58
|
def forward(self, x, x_skip=None):
|
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
|
|
73
80
|
class AttnBlock(nn.Module):
|
74
81
|
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
75
82
|
super().__init__()
|
83
|
+
|
84
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
85
|
+
|
76
86
|
self.self_attn = self_attn
|
77
87
|
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
78
88
|
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
|
79
|
-
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
|
89
|
+
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
|
80
90
|
|
81
91
|
def forward(self, x, kv):
|
82
92
|
kv = self.kv_mapper(kv)
|
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
|
|
28
28
|
AttnAddedKVProcessor,
|
29
29
|
AttnProcessor,
|
30
30
|
)
|
31
|
+
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
31
32
|
from ...models.modeling_utils import ModelMixin
|
32
|
-
from ...utils import is_torch_version
|
33
|
+
from ...utils import USE_PEFT_BACKEND, is_torch_version
|
33
34
|
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
|
34
35
|
|
35
36
|
|
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
|
40
41
|
@register_to_config
|
41
42
|
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
|
42
43
|
super().__init__()
|
44
|
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
45
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
46
|
+
|
43
47
|
self.c_r = c_r
|
44
|
-
self.projection = nn.Conv2d(c_in, c, kernel_size=1)
|
48
|
+
self.projection = conv_cls(c_in, c, kernel_size=1)
|
45
49
|
self.cond_mapper = nn.Sequential(
|
46
|
-
nn.Linear(c_cond, c),
|
50
|
+
linear_cls(c_cond, c),
|
47
51
|
nn.LeakyReLU(0.2),
|
48
|
-
nn.Linear(c, c),
|
52
|
+
linear_cls(c, c),
|
49
53
|
)
|
50
54
|
|
51
55
|
self.blocks = nn.ModuleList()
|
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
|
55
59
|
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
|
56
60
|
self.out = nn.Sequential(
|
57
61
|
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
|
58
|
-
nn.Conv2d(c, c_in * 2, kernel_size=1),
|
62
|
+
conv_cls(c, c_in * 2, kernel_size=1),
|
59
63
|
)
|
60
64
|
|
61
65
|
self.gradient_checkpointing = False
|
@@ -269,7 +269,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
|
269
269
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
270
270
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
271
271
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
272
|
-
`._callback_tensor_inputs` attribute of your
|
272
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
273
273
|
|
274
274
|
Examples:
|
275
275
|
|
@@ -234,7 +234,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
|
234
234
|
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
|
235
235
|
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
|
236
236
|
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
|
237
|
-
the `._callback_tensor_inputs` attribute of your
|
237
|
+
the `._callback_tensor_inputs` attribute of your pipeline class.
|
238
238
|
callback_on_step_end (`Callable`, *optional*):
|
239
239
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
240
240
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
@@ -243,7 +243,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
|
243
243
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
244
244
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
245
245
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
246
|
-
`._callback_tensor_inputs` attribute of your
|
246
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
247
247
|
|
248
248
|
Examples:
|
249
249
|
|
@@ -69,6 +69,10 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
|
69
69
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
70
70
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
71
71
|
|
72
|
+
The pipeline also inherits the following loading methods:
|
73
|
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
74
|
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
75
|
+
|
72
76
|
Args:
|
73
77
|
prior ([`Prior`]):
|
74
78
|
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
@@ -349,7 +353,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
|
349
353
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
350
354
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
351
355
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
352
|
-
`._callback_tensor_inputs` attribute of your
|
356
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
353
357
|
|
354
358
|
Examples:
|
355
359
|
|
diffusers/schedulers/__init__.py
CHANGED
@@ -38,6 +38,8 @@ except OptionalDependencyNotAvailable:
|
|
38
38
|
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
|
39
39
|
|
40
40
|
else:
|
41
|
+
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
42
|
+
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
|
41
43
|
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
42
44
|
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
43
45
|
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
@@ -56,12 +58,10 @@ else:
|
|
56
58
|
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
|
57
59
|
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
58
60
|
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
|
59
|
-
_import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
|
60
61
|
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
|
61
62
|
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
62
63
|
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
63
64
|
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
|
64
|
-
_import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
|
65
65
|
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
|
66
66
|
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
|
67
67
|
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
|
@@ -129,6 +129,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
129
129
|
except OptionalDependencyNotAvailable:
|
130
130
|
from ..utils.dummy_pt_objects import * # noqa F403
|
131
131
|
else:
|
132
|
+
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
133
|
+
from .scheduling_amused import AmusedScheduler
|
132
134
|
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
133
135
|
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
134
136
|
from .scheduling_ddim import DDIMScheduler
|
@@ -147,12 +149,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
147
149
|
from .scheduling_ipndm import IPNDMScheduler
|
148
150
|
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
149
151
|
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
|
150
|
-
from .scheduling_karras_ve import KarrasVeScheduler
|
151
152
|
from .scheduling_lcm import LCMScheduler
|
152
153
|
from .scheduling_pndm import PNDMScheduler
|
153
154
|
from .scheduling_repaint import RePaintScheduler
|
154
155
|
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
155
|
-
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
156
156
|
from .scheduling_unclip import UnCLIPScheduler
|
157
157
|
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
158
158
|
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
@@ -0,0 +1,50 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
from ...utils import (
|
4
|
+
DIFFUSERS_SLOW_IMPORT,
|
5
|
+
OptionalDependencyNotAvailable,
|
6
|
+
_LazyModule,
|
7
|
+
get_objects_from_module,
|
8
|
+
is_torch_available,
|
9
|
+
is_transformers_available,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
_dummy_objects = {}
|
14
|
+
_import_structure = {}
|
15
|
+
|
16
|
+
try:
|
17
|
+
if not (is_transformers_available() and is_torch_available()):
|
18
|
+
raise OptionalDependencyNotAvailable()
|
19
|
+
except OptionalDependencyNotAvailable:
|
20
|
+
from ...utils import dummy_pt_objects # noqa F403
|
21
|
+
|
22
|
+
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
23
|
+
else:
|
24
|
+
_import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
|
25
|
+
_import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
|
26
|
+
|
27
|
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
28
|
+
try:
|
29
|
+
if not is_torch_available():
|
30
|
+
raise OptionalDependencyNotAvailable()
|
31
|
+
|
32
|
+
except OptionalDependencyNotAvailable:
|
33
|
+
from ..utils.dummy_pt_objects import * # noqa F403
|
34
|
+
else:
|
35
|
+
from .scheduling_karras_ve import KarrasVeScheduler
|
36
|
+
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
37
|
+
|
38
|
+
|
39
|
+
else:
|
40
|
+
import sys
|
41
|
+
|
42
|
+
sys.modules[__name__] = _LazyModule(
|
43
|
+
__name__,
|
44
|
+
globals()["__file__"],
|
45
|
+
_import_structure,
|
46
|
+
module_spec=__spec__,
|
47
|
+
)
|
48
|
+
|
49
|
+
for name, value in _dummy_objects.items():
|
50
|
+
setattr(sys.modules[__name__], name, value)
|
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
|
|
19
19
|
import numpy as np
|
20
20
|
import torch
|
21
21
|
|
22
|
-
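from ..configuration_utils import ConfigMixin, register_to_config
|
23
|
-
from ..utils import BaseOutput
|
24
|
-
from ..utils.torch_utils import randn_tensor
|
25
|
-
from .scheduling_utils import SchedulerMixin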
|
22
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
23
|
+
from ...utils import BaseOutput
|
24
|
+
from ...utils.torch_utils import randn_tensor
|
25
|
+
from ..scheduling_utils import SchedulerMixin
|
26
26
|
|
27
27
|
|
28
28
|
@dataclass
|
@@ -19,9 +19,9 @@ from typing import Union
|
|
19
19
|
|
20
20
|
import torch
|
21
21
|
|
22
|
-
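from ..configuration_utils import ConfigMixin, register_to_config
|
23
|
-
from ..utils.torch_utils import randn_tensor
|
24
|
-
from .scheduling_utils import SchedulerMixin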
|
22
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
23
|
+
from ...utils.torch_utils import randn_tensor
|
24
|
+
from ..scheduling_utils import SchedulerMixin
|
25
25
|
|
26
26
|
|
27
27
|
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
@@ -79,9 +79,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
|
79
79
|
|
80
80
|
# TODO(Patrick) better comments + non-PyTorch
|
81
81
|
# postprocess model score
|
82
|
-
log_mean_coeff = (
|
83
|
-
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
84
|
-
)
|
82
|
+
log_mean_coeff = -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
85
83
|
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
|
86
84
|
std = std.flatten()
|
87
85
|
while len(std.shape) < len(score.shape):
|
@@ -0,0 +1,162 @@
|
|
1
|
+
import math
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import List, Optional, Tuple, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from ..configuration_utils import ConfigMixin, register_to_config
|
8
|
+
from ..utils import BaseOutput
|
9
|
+
from .scheduling_utils import SchedulerMixin
|
10
|
+
|
11
|
+
|
12
|
+
def gumbel_noise(t, generator=None):
|
13
|
+
device = generator.device if generator is not None else t.device
|
14
|
+
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
15
|
+
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
16
|
+
|
17
|
+
|
18
|
+
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
19
|
+
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
20
|
+
sorted_confidence = torch.sort(confidence, dim=-1).values
|
21
|
+
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
22
|
+
masking = confidence < cut_off
|
23
|
+
return masking
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class AmusedSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
            Token ids for the previous timestep, in which the least confident predictions have been replaced by the
            mask token id. `prev_sample` should be used as next model input in the denoising loop.
        pred_original_sample (`torch.LongTensor` with the same shape as `prev_sample`, *optional*):
            The fully unmasked token ids sampled from the model output at the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # The scheduler works on discrete token ids, hence integer tensors rather than float images.
    prev_sample: torch.LongTensor
    pred_original_sample: Optional[torch.LongTensor] = None
|
42
|
+
|
43
|
+
|
44
|
+
class AmusedScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler for iterative masked-token generation.

    Positions of the sample whose id equals `mask_token_id` are treated as unknown. Every `step` samples a token for
    each unknown position from the model's logits and then re-masks the least confident predictions; the number of
    positions left masked follows the configured masking schedule.

    Args:
        mask_token_id (`int`):
            Token id that marks a position as masked.
        masking_schedule (`str`, defaults to `"cosine"`):
            Shape of the masking schedule, either `"cosine"` or `"linear"`. Any other value makes `step` (for
            non-final timesteps) and `add_noise` raise a `ValueError`.
    """

    order = 1

    temperatures: torch.Tensor

    @register_to_config
    def __init__(
        self,
        mask_token_id: int,
        masking_schedule: str = "cosine",
    ):
        # Both attributes are populated by `set_timesteps`.
        self.temperatures = None
        self.timesteps = None

    def set_timesteps(
        self,
        num_inference_steps: int,
        temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
        device: Union[str, torch.device] = None,
    ):
        """
        Build the timestep sequence and the per-step sampling temperatures.

        Args:
            num_inference_steps (`int`):
                Number of steps; timesteps run from `num_inference_steps - 1` down to `0`.
            temperature (`float` or a pair of `float`, defaults to `(2, 0)`):
                A `(start, end)` pair is interpolated linearly over the steps. A single number is interpolated from
                that value down to `0.01`.
            device (`str` or `torch.device`, *optional*):
                Device on which `timesteps` and `temperatures` are created.
        """
        self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

        if isinstance(temperature, (tuple, list)):
            self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
        else:
            self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)

    def _schedule_position(self, timestep):
        """Return `(step_idx, mask_ratio)`: the index of `timestep` in `self.timesteps` and its masking ratio."""
        if self.timesteps is None:
            raise ValueError("`timesteps` is not set; call `set_timesteps` before `step` or `add_noise`.")

        step_idx = (self.timesteps == timestep).nonzero()
        ratio = (step_idx + 1) / len(self.timesteps)

        if self.config.masking_schedule == "cosine":
            mask_ratio = torch.cos(ratio * math.pi / 2)
        elif self.config.masking_schedule == "linear":
            mask_ratio = 1 - ratio
        else:
            raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

        return step_idx, mask_ratio

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.LongTensor,
        starting_mask_ratio: float = 1,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[AmusedSchedulerOutput, Tuple]:
        """
        Sample tokens for the masked positions and re-mask the least confident ones.

        Args:
            model_output (`torch.FloatTensor`):
                Logits over the codebook, of shape `(batch_size, codebook_size, height, width)` when `sample` is
                `(batch_size, height, width)`, otherwise `(batch_size, sequence_length, codebook_size)`.
            timestep (`int` or `torch.Tensor`):
                Current timestep; must be one of `self.timesteps`. At timestep `0` nothing is re-masked.
            sample (`torch.LongTensor`):
                Current token ids, of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`.
            starting_mask_ratio (`float`, defaults to 1):
                Factor applied to the scheduled masking ratio.
            generator (`torch.Generator`, *optional*):
                Random generator used for token sampling and for the confidence noise.
            return_dict (`bool`, defaults to `True`):
                Whether to return an `AmusedSchedulerOutput` instead of a plain tuple.

        Returns:
            `AmusedSchedulerOutput` or `tuple`: `(prev_sample, pred_original_sample)`, reshaped back to
            `(batch_size, height, width)` for image-shaped input.

        Raises:
            ValueError: for a non-final timestep, if `set_timesteps` was not called or the configured masking
            schedule is unknown.
        """
        two_dim_input = sample.ndim == 3 and model_output.ndim == 4

        if two_dim_input:
            # Flatten the spatial grid so the rest of the method works on (batch, seq_len[, codebook]).
            batch_size, codebook_size, height, width = model_output.shape
            sample = sample.reshape(batch_size, height * width)
            model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)

        unknown_map = sample == self.config.mask_token_id

        probs = model_output.softmax(dim=-1)

        device = probs.device
        probs_ = probs.to(generator.device) if generator is not None else probs  # handles when generator is on CPU
        if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
            probs_ = probs_.float()  # multinomial is not implemented for cpu half precision
        probs_ = probs_.reshape(-1, probs.size(-1))
        pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
        pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
        # Tokens that were already known are kept as they are.
        pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)

        if timestep == 0:
            prev_sample = pred_original_sample
        else:
            seq_len = sample.shape[1]
            step_idx, mask_ratio = self._schedule_position(timestep)

            mask_ratio = starting_mask_ratio * mask_ratio

            mask_len = (seq_len * mask_ratio).floor()
            # do not mask more than amount previously masked
            mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            # mask at least one
            mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)

            selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)

            # Masks tokens with lower confidence.
            prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)

        if two_dim_input:
            prev_sample = prev_sample.reshape(batch_size, height, width)
            pred_original_sample = pred_original_sample.reshape(batch_size, height, width)

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return AmusedSchedulerOutput(prev_sample, pred_original_sample)

    def add_noise(self, sample, timesteps, generator=None):
        """
        Randomly replace tokens of `sample` with the mask token id.

        Each position is masked independently with probability equal to the scheduled masking ratio at `timesteps`.

        Args:
            sample (`torch.LongTensor`):
                Token ids to corrupt; the input tensor is not modified.
            timesteps (`int` or `torch.Tensor`):
                Timestep whose masking ratio is used; must be one of `self.timesteps`.
            generator (`torch.Generator`, *optional*):
                Random generator; when given, the random draw happens on the generator's device.

        Returns:
            `torch.LongTensor`: a copy of `sample` with the selected positions set to the mask token id.

        Raises:
            ValueError: if `set_timesteps` was not called or the configured masking schedule is unknown.
        """
        _, mask_ratio = self._schedule_position(timesteps)

        rand_device = generator.device if generator is not None else sample.device
        mask_indices = torch.rand(sample.shape, device=rand_device, generator=generator).to(sample.device) < mask_ratio

        masked_sample = sample.clone()
        masked_sample[mask_indices] = self.config.mask_token_id

        return masked_sample
|
@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
98
98
|
self.custom_timesteps = False
|
99
99
|
self.is_scale_input_called = False
|
100
100
|
self._step_index = None
|
101
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
101
102
|
|
102
103
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
103
104
|
if schedule_timesteps is None:
|
@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
230
231
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
231
232
|
|
232
233
|
self._step_index = None
|
234
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
233
235
|
|
234
236
|
# Modified _convert_to_karras implementation that takes in ramp as argument
|
235
237
|
def _convert_to_karras(self, ramp):
|
@@ -208,9 +208,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
208
208
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
209
209
|
elif beta_schedule == "scaled_linear":
|
210
210
|
# this schedule is very specific to the latent diffusion model.
|
211
|
-
self.betas = (
|
212
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
213
|
-
)
|
211
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
214
212
|
elif beta_schedule == "squaredcos_cap_v2":
|
215
213
|
# Glide cosine schedule
|
216
214
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -204,9 +204,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
204
204
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
205
205
|
elif beta_schedule == "scaled_linear":
|
206
206
|
# this schedule is very specific to the latent diffusion model.
|
207
|
-
self.betas = (
|
208
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
209
|
-
)
|
207
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
210
208
|
elif beta_schedule == "squaredcos_cap_v2":
|
211
209
|
# Glide cosine schedule
|
212
210
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -295,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
295
293
|
model_output: torch.FloatTensor,
|
296
294
|
timestep: int,
|
297
295
|
sample: torch.FloatTensor,
|
298
|
-
eta: float = 0.0,
|
299
|
-
use_clipped_model_output: bool = False,
|
300
|
-
variance_noise: Optional[torch.FloatTensor] = None,
|
301
296
|
return_dict: bool = True,
|
302
297
|
) -> Union[DDIMSchedulerOutput, Tuple]:
|
303
298
|
"""
|
@@ -334,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
334
329
|
# 1. get previous step value (=t+1)
|
335
330
|
prev_timestep = timestep
|
336
331
|
timestep = min(
|
337
|
-
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
|
332
|
+
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
|
338
333
|
)
|
339
334
|
|
340
335
|
# 2. compute alphas, betas
|
@@ -215,9 +215,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
215
215
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
216
216
|
elif beta_schedule == "scaled_linear":
|
217
217
|
# this schedule is very specific to the latent diffusion model.
|
218
|
-
self.betas = (
|
219
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
220
|
-
)
|
218
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
221
219
|
elif beta_schedule == "squaredcos_cap_v2":
|
222
220
|
# Glide cosine schedule
|
223
221
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -89,6 +89,43 @@ def betas_for_alpha_bar(
|
|
89
89
|
return torch.tensor(betas, dtype=torch.float32)
|
90
90
|
|
91
91
|
|
92
|
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
93
|
+
def rescale_zero_terminal_snr(betas):
    """
    Rescale `betas` so that the final timestep has zero signal-to-noise ratio.

    Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR.
    """
    # betas -> sqrt of the cumulative alpha product
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the end points before they are modified.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last value becomes zero, then scale so the first value returns to its old value.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # Undo the sqrt and the cumulative product to get back to per-step alphas.
    alpha_bar = rescaled**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
|
127
|
+
|
128
|
+
|
92
129
|
class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
93
130
|
"""
|
94
131
|
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
|
@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
131
168
|
An offset added to the inference steps. You can use a combination of `offset=1` and
|
132
169
|
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
133
170
|
Diffusion.
|
171
|
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
172
|
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
173
|
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
174
|
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
134
175
|
"""
|
135
176
|
|
136
177
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -153,6 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
153
194
|
sample_max_value: float = 1.0,
|
154
195
|
timestep_spacing: str = "leading",
|
155
196
|
steps_offset: int = 0,
|
197
|
+
rescale_betas_zero_snr: int = False,
|
156
198
|
):
|
157
199
|
if trained_betas is not None:
|
158
200
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -160,9 +202,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
160
202
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
161
203
|
elif beta_schedule == "scaled_linear":
|
162
204
|
# this schedule is very specific to the latent diffusion model.
|
163
|
-
self.betas = (
|
164
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
165
|
-
)
|
205
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
166
206
|
elif beta_schedule == "squaredcos_cap_v2":
|
167
207
|
# Glide cosine schedule
|
168
208
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -173,6 +213,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
173
213
|
else:
|
174
214
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
175
215
|
|
216
|
+
# Rescale for zero SNR
|
217
|
+
if rescale_betas_zero_snr:
|
218
|
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
219
|
+
|
176
220
|
self.alphas = 1.0 - self.betas
|
177
221
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
178
222
|
self.one = torch.tensor(1.0)
|