diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl
- diffusers/__init__.py +3 -1
- diffusers/commands/fp16_safetensors.py +2 -7
- diffusers/configuration_utils.py +23 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/loaders.py +62 -64
- diffusers/models/__init__.py +1 -0
- diffusers/models/activations.py +2 -0
- diffusers/models/attention.py +45 -1
- diffusers/models/autoencoder_tiny.py +193 -0
- diffusers/models/controlnet.py +1 -1
- diffusers/models/embeddings.py +56 -0
- diffusers/models/lora.py +0 -6
- diffusers/models/modeling_flax_utils.py +28 -2
- diffusers/models/modeling_utils.py +33 -16
- diffusers/models/transformer_2d.py +26 -9
- diffusers/models/unet_1d.py +2 -2
- diffusers/models/unet_2d_blocks.py +106 -56
- diffusers/models/unet_2d_condition.py +20 -5
- diffusers/models/vae.py +106 -1
- diffusers/pipelines/__init__.py +1 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/auto_pipeline.py +33 -43
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/pipeline_flax_utils.py +41 -4
- diffusers/pipelines/pipeline_utils.py +60 -16
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
- diffusers/schedulers/scheduling_consistency_models.py +70 -57
- diffusers/schedulers/scheduling_ddim.py +76 -71
- diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
- diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
- diffusers/schedulers/scheduling_ddpm.py +68 -67
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
- diffusers/schedulers/scheduling_deis_multistep.py +93 -85
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
- diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
- diffusers/schedulers/scheduling_euler_discrete.py +63 -56
- diffusers/schedulers/scheduling_heun_discrete.py +57 -45
- diffusers/schedulers/scheduling_ipndm.py +27 -22
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
- diffusers/schedulers/scheduling_karras_ve.py +55 -45
- diffusers/schedulers/scheduling_lms_discrete.py +58 -52
- diffusers/schedulers/scheduling_pndm.py +77 -62
- diffusers/schedulers/scheduling_repaint.py +56 -38
- diffusers/schedulers/scheduling_sde_ve.py +62 -50
- diffusers/schedulers/scheduling_sde_vp.py +32 -11
- diffusers/schedulers/scheduling_unclip.py +3 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
- diffusers/schedulers/scheduling_utils.py +41 -35
- diffusers/schedulers/scheduling_utils_flax.py +8 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- diffusers/utils/hub_utils.py +105 -2
- diffusers/utils/import_utils.py +0 -4
- diffusers/utils/pil_utils.py +19 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
- diffusers/models/cross_attention.py +0 -94
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/models/unet_2d_blocks.py
CHANGED
@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..utils import is_torch_version, logging
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -48,6 +49,7 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    attention_type="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
@@ -128,6 +130,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
         )
     elif down_block_type == "SimpleCrossAttnDownBlock2D":
         if cross_attention_dim is None:
@@ -243,6 +246,7 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    attention_type="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
@@ -306,6 +310,7 @@ def get_up_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
         )
     elif up_block_type == "SimpleCrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -423,6 +428,28 @@ def get_up_block(
     raise ValueError(f"{up_block_type} does not exist.")
 
 
+class AutoencoderTinyBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+        super().__init__()
+        act_fn = get_activation(act_fn)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+            if in_channels != out_channels
+            else nn.Identity()
+        )
+        self.fuse = nn.ReLU()
+
+    def forward(self, x):
+        return self.fuse(self.conv(x) + self.skip(x))
+
+
 class UNetMidBlock2D(nn.Module):
     def __init__(
         self,
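The AutoencoderTinyBlock added here is a shape-preserving residual block: a three-convolution stack plus a skip path, fused through a ReLU. A minimal usage sketch (channel sizes are illustrative; with matching channels the skip path is an identity):

import torch
from diffusers.models.unet_2d_blocks import AutoencoderTinyBlock

# 64 -> 64 channels, so the skip path is nn.Identity(); with differing channel
# counts it would be a 1x1 convolution instead.
block = AutoencoderTinyBlock(in_channels=64, out_channels=64, act_fn="relu")
out = block(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32]): spatial size is unchanged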
@@ -533,6 +560,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
 
@@ -569,6 +597,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -600,6 +629,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -611,15 +642,42 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                return_dict=False,
-            )[0]
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
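Since UNetMidBlock2DCrossAttn now carries its own gradient_checkpointing flag, it is picked up by the usual ModelMixin toggle. A rough sketch of enabling it (the checkpoint id is only an example, and the checkpointed branch above runs only in training mode):

import torch
from diffusers import UNet2DConditionModel

# Example checkpoint; any UNet2DConditionModel config behaves the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing = True on supporting submodules
unet.train()                          # the mid-block checkpoint path requires self.training

out = unet(
    torch.randn(1, 4, 64, 64),                      # noisy latents
    torch.tensor([10]),                             # timestep
    encoder_hidden_states=torch.randn(1, 77, 768),  # prompt embeddings
).sample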
@@ -882,6 +940,7 @@ class CrossAttnDownBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -918,6 +977,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -980,16 +1040,13 @@ class CrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1656,13 +1713,12 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1857,15 +1913,13 @@ class KCrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2022,6 +2076,7 @@ class CrossAttnUpBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -2060,6 +2115,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -2118,16 +2174,13 @@ class CrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2817,13 +2870,12 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -3039,16 +3091,14 @@ class KCrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
diffusers/models/unet_2d_condition.py
CHANGED
@@ -28,6 +28,7 @@ from .embeddings import (
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
+    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -36,12 +37,8 @@ from .embeddings import (
 )
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    CrossAttnUpBlock2D,
-    DownBlock2D,
     UNetMidBlock2DCrossAttn,
     UNetMidBlock2DSimpleCrossAttn,
-    UpBlock2D,
     get_down_block,
     get_up_block,
 )
@@ -202,6 +199,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
@@ -450,6 +448,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -473,6 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                attention_type=attention_type,
             )
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(
@@ -539,6 +539,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -564,6 +565,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+        if attention_type == "gated":
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -694,7 +703,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(
@@ -899,6 +908,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # 2. pre-process
         sample = self.conv_in(sample)
 
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
         # 3. down
 
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
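The `gligen` entry popped from cross_attention_kwargs is forwarded to the new PositionNet, and its output is re-inserted under the "objs" key for the gated attention layers. A hedged sketch of that step in isolation; the keyword names and shapes follow the GLIGEN integration and are illustrative:

import torch
from diffusers.models.embeddings import PositionNet

# positive_len / out_dim of 768 match a SD 1.x cross_attention_dim, as in the hunk above.
position_net = PositionNet(positive_len=768, out_dim=768)

objs = position_net(
    boxes=torch.zeros(1, 30, 4),                  # normalized xyxy boxes, padded to 30 entries
    masks=torch.zeros(1, 30),                     # 1 for a real box, 0 for padding
    positive_embeddings=torch.zeros(1, 30, 768),  # one text embedding per grounded phrase
)
print(objs.shape)  # (1, 30, 768): the "objs" tokens consumed by gated self-attention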
diffusers/models/vae.py
CHANGED
@@ -19,8 +19,9 @@ import torch
 import torch.nn as nn
 
 from ..utils import BaseOutput, is_torch_version, randn_tensor
+from .activations import get_activation
 from .attention_processor import SpatialNorm
-from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
 
 
 @dataclass
@@ -686,3 +687,107 @@ class DiagonalGaussianDistribution(object):
 
     def mode(self):
         return self.mean
+
+
+class EncoderTiny(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        block_out_channels: int,
+        act_fn: str,
+    ):
+        super().__init__()
+
+        layers = []
+        for i, num_block in enumerate(num_blocks):
+            num_channels = block_out_channels[i]
+
+            if i == 0:
+                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
+            else:
+                layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            x = self.layers(x)
+
+        return x
+
+
+class DecoderTiny(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        block_out_channels: int,
+        upsampling_scaling_factor: int,
+        act_fn: str,
+    ):
+        super().__init__()
+
+        layers = [
+            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
+            get_activation(act_fn),
+        ]
+
+        for i, num_block in enumerate(num_blocks):
+            is_final_block = i == (len(num_blocks) - 1)
+            num_channels = block_out_channels[i]
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+            if not is_final_block:
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+
+            conv_out_channel = num_channels if not is_final_block else out_channels
+            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x):
+        # Clamp.
+        x = torch.tanh(x / 3) * 3
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            x = self.layers(x)
+
+        return x
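EncoderTiny and DecoderTiny back the new AutoencoderTiny model (see diffusers/models/autoencoder_tiny.py in the file list). A rough sketch of the shapes involved, assuming TAESD-style defaults for the block layout:

import torch
from diffusers.models.vae import DecoderTiny, EncoderTiny

enc = EncoderTiny(
    in_channels=3, out_channels=4,
    num_blocks=(1, 3, 3, 3), block_out_channels=(64, 64, 64, 64), act_fn="relu",
)
dec = DecoderTiny(
    in_channels=4, out_channels=3,
    num_blocks=(3, 3, 3, 1), block_out_channels=(64, 64, 64, 64),
    upsampling_scaling_factor=2, act_fn="relu",
)

latents = enc(torch.randn(1, 3, 256, 256))  # (1, 4, 32, 32): every stage after the first downsamples by 2
image = dec(latents)                        # (1, 3, 256, 256): the decoder mirrors this with nn.Upsample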
diffusers/pipelines/__init__.py
CHANGED
@@ -90,6 +90,7 @@ else:
         StableDiffusionAttendAndExcitePipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
+        StableDiffusionGLIGENPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
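The newly exported StableDiffusionGLIGENPipeline (pipeline_stable_diffusion_gligen.py, +832 lines in the file list) routes grounded boxes through the mechanism above. A hedged usage sketch; the checkpoint id and argument names follow the GLIGEN integration and should be treated as illustrative:

import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a birthday cake on a wooden table",
    gligen_phrases=["a birthday cake"],
    gligen_boxes=[[0.25, 0.45, 0.75, 0.95]],  # normalized xyxy box for the phrase
    gligen_scheduled_sampling_beta=0.3,
    num_inference_steps=50,
).images[0]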
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
CHANGED
@@ -334,7 +334,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )
         prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
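The fallback above keeps prompt encoding usable when components are swapped out, for example when pre-computed prompt_embeds are passed to a pipeline loaded with text_encoder=None. A standalone sketch of the same selection order, using a hypothetical helper name:

import torch

def resolve_prompt_embeds_dtype(text_encoder, unet, prompt_embeds):
    # Mirrors the fallback added above: prefer the text encoder's dtype, then the
    # UNet's, then whatever dtype the embeddings already carry.
    if text_encoder is not None:
        return text_encoder.dtype
    if unet is not None:
        return unet.dtype
    return prompt_embeds.dtype

embeds = torch.randn(1, 77, 768, dtype=torch.float32)
print(resolve_prompt_embeds_dtype(None, None, embeds))  # torch.float32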
@@ -390,7 +397,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -585,7 +592,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             guidance_rescale (`float`, *optional*, defaults to 0.7):
                 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
CHANGED
@@ -335,7 +335,14 @@ class AltDiffusionImg2ImgPipeline(
             )
         prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -391,7 +398,7 @@ class AltDiffusionImg2ImgPipeline(
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -634,7 +641,7 @@ class AltDiffusionImg2ImgPipeline(
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
 
         Examples:
 
diffusers/pipelines/audioldm/pipeline_audioldm.py
CHANGED
@@ -428,7 +428,7 @@ class AudioLDMPipeline(DiffusionPipeline):
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             output_type (`str`, *optional*, defaults to `"np"`):
                 The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or
                 `"pt"` to return a PyTorch `torch.Tensor` object.