diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl

Files changed (114)
  1. diffusers/__init__.py +3 -1
  2. diffusers/commands/fp16_safetensors.py +2 -7
  3. diffusers/configuration_utils.py +23 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/loaders.py +62 -64
  6. diffusers/models/__init__.py +1 -0
  7. diffusers/models/activations.py +2 -0
  8. diffusers/models/attention.py +45 -1
  9. diffusers/models/autoencoder_tiny.py +193 -0
  10. diffusers/models/controlnet.py +1 -1
  11. diffusers/models/embeddings.py +56 -0
  12. diffusers/models/lora.py +0 -6
  13. diffusers/models/modeling_flax_utils.py +28 -2
  14. diffusers/models/modeling_utils.py +33 -16
  15. diffusers/models/transformer_2d.py +26 -9
  16. diffusers/models/unet_1d.py +2 -2
  17. diffusers/models/unet_2d_blocks.py +106 -56
  18. diffusers/models/unet_2d_condition.py +20 -5
  19. diffusers/models/vae.py +106 -1
  20. diffusers/pipelines/__init__.py +1 -0
  21. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
  22. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
  23. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  24. diffusers/pipelines/auto_pipeline.py +33 -43
  25. diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
  26. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
  27. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
  28. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
  29. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
  30. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
  31. diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
  32. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
  33. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
  34. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
  35. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
  36. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
  37. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
  38. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
  39. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  40. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  41. diffusers/pipelines/pipeline_flax_utils.py +41 -4
  42. diffusers/pipelines/pipeline_utils.py +60 -16
  43. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
  44. diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  45. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
  46. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
  47. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
  48. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
  49. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
  50. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
  51. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
  52. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
  53. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
  54. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
  55. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
  56. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
  57. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
  58. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
  59. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
  60. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
  61. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
  65. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
  66. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
  67. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
  68. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
  69. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
  70. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
  71. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
  72. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
  73. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
  74. diffusers/schedulers/scheduling_consistency_models.py +70 -57
  75. diffusers/schedulers/scheduling_ddim.py +76 -71
  76. diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
  77. diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
  78. diffusers/schedulers/scheduling_ddpm.py +68 -67
  79. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
  80. diffusers/schedulers/scheduling_deis_multistep.py +93 -85
  81. diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
  82. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
  83. diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
  84. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
  85. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
  86. diffusers/schedulers/scheduling_euler_discrete.py +63 -56
  87. diffusers/schedulers/scheduling_heun_discrete.py +57 -45
  88. diffusers/schedulers/scheduling_ipndm.py +27 -22
  89. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
  90. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
  91. diffusers/schedulers/scheduling_karras_ve.py +55 -45
  92. diffusers/schedulers/scheduling_lms_discrete.py +58 -52
  93. diffusers/schedulers/scheduling_pndm.py +77 -62
  94. diffusers/schedulers/scheduling_repaint.py +56 -38
  95. diffusers/schedulers/scheduling_sde_ve.py +62 -50
  96. diffusers/schedulers/scheduling_sde_vp.py +32 -11
  97. diffusers/schedulers/scheduling_unclip.py +3 -3
  98. diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
  99. diffusers/schedulers/scheduling_utils.py +41 -35
  100. diffusers/schedulers/scheduling_utils_flax.py +8 -2
  101. diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
  102. diffusers/utils/__init__.py +2 -2
  103. diffusers/utils/dummy_pt_objects.py +15 -0
  104. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  105. diffusers/utils/hub_utils.py +105 -2
  106. diffusers/utils/import_utils.py +0 -4
  107. diffusers/utils/pil_utils.py +19 -0
  108. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
  109. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
  110. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
  111. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
  112. diffusers/models/cross_attention.py +0 -94
  113. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
  114. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/models/unet_2d_blocks.py CHANGED
@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import is_torch_version, logging
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -48,6 +49,7 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    attention_type="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
@@ -128,6 +130,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
         )
     elif down_block_type == "SimpleCrossAttnDownBlock2D":
         if cross_attention_dim is None:
@@ -243,6 +246,7 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    attention_type="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
@@ -306,6 +310,7 @@ def get_up_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
         )
     elif up_block_type == "SimpleCrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -423,6 +428,28 @@ def get_up_block(
     raise ValueError(f"{up_block_type} does not exist.")


+class AutoencoderTinyBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+        super().__init__()
+        act_fn = get_activation(act_fn)
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            act_fn,
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+            if in_channels != out_channels
+            else nn.Identity()
+        )
+        self.fuse = nn.ReLU()
+
+    def forward(self, x):
+        return self.fuse(self.conv(x) + self.skip(x))
+
+
 class UNetMidBlock2D(nn.Module):
     def __init__(
         self,
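For readers skimming the diff: the new AutoencoderTinyBlock is a small residual block (three 3x3 convolutions with activations, a 1x1 projection on the skip path when the channel count changes, and a ReLU fusing the two paths). A minimal sketch of how it behaves; the shapes below are made up for illustration and are not taken from the diff:

import torch
from diffusers.models.unet_2d_blocks import AutoencoderTinyBlock

# 64 -> 128 channels, so the skip path becomes a 1x1 conv instead of nn.Identity()
block = AutoencoderTinyBlock(in_channels=64, out_channels=128, act_fn="relu")
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # spatial size preserved: torch.Size([1, 128, 32, 32])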
@@ -533,6 +560,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()

@@ -569,6 +597,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -600,6 +629,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -611,15 +642,42 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                return_dict=False,
-            )[0]
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

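With the mid block now carrying its own gradient_checkpointing flag (and _set_gradient_checkpointing switching to a hasattr check later in this diff), checkpointing covers the whole UNet rather than only the down and up blocks. A hedged sketch of turning it on through the existing model-level helper; the checkpoint id is just an example:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_gradient_checkpointing()  # recursively sets module.gradient_checkpointing = True
unet.train()
# In training mode the mid block now recomputes its resnet activations during the backward
# pass instead of storing them, trading extra compute for lower memory use.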
@@ -882,6 +940,7 @@ class CrossAttnDownBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -918,6 +977,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -980,16 +1040,13 @@ class CrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1656,13 +1713,12 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     return custom_forward

                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)

@@ -1857,15 +1913,13 @@ class KCrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2022,6 +2076,7 @@ class CrossAttnUpBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -2060,6 +2115,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -2118,16 +2174,13 @@ class CrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2817,13 +2870,12 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     return custom_forward

                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)

@@ -3039,16 +3091,14 @@ class KCrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
diffusers/models/unet_2d_condition.py CHANGED
@@ -28,6 +28,7 @@ from .embeddings import (
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
+    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -36,12 +37,8 @@ from .embeddings import (
 )
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    CrossAttnUpBlock2D,
-    DownBlock2D,
     UNetMidBlock2DCrossAttn,
     UNetMidBlock2DSimpleCrossAttn,
-    UpBlock2D,
     get_down_block,
     get_up_block,
 )
@@ -202,6 +199,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
@@ -450,6 +448,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -473,6 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                attention_type=attention_type,
             )
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(
@@ -539,6 +539,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -564,6 +565,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+        if attention_type == "gated":
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
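The new attention_type config knob is threaded through every cross-attention block, and "gated" additionally instantiates a PositionNet for GLIGEN-style grounded generation. A rough, hypothetical construction just to show the flag; the tiny block and channel settings below are invented for brevity and are not a real checkpoint configuration:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_type="gated",  # creates self.position_net, which turns GLIGEN boxes/phrases into extra tokens
)
print(unet.position_net)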
@@ -694,7 +703,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             fn_recursive_set_attention_slice(module, reversed_slice_size)

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

     def forward(
@@ -899,6 +908,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # 2. pre-process
         sample = self.conv_in(sample)

+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
         # 3. down

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
diffusers/models/vae.py CHANGED
@@ -19,8 +19,9 @@ import torch
 import torch.nn as nn

 from ..utils import BaseOutput, is_torch_version, randn_tensor
+from .activations import get_activation
 from .attention_processor import SpatialNorm
-from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block


 @dataclass
@@ -686,3 +687,107 @@ class DiagonalGaussianDistribution(object):

     def mode(self):
         return self.mean
+
+
+class EncoderTiny(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        block_out_channels: int,
+        act_fn: str,
+    ):
+        super().__init__()
+
+        layers = []
+        for i, num_block in enumerate(num_blocks):
+            num_channels = block_out_channels[i]
+
+            if i == 0:
+                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
+            else:
+                layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            x = self.layers(x)
+
+        return x
+
+
+class DecoderTiny(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        block_out_channels: int,
+        upsampling_scaling_factor: int,
+        act_fn: str,
+    ):
+        super().__init__()
+
+        layers = [
+            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
+            get_activation(act_fn),
+        ]
+
+        for i, num_block in enumerate(num_blocks):
+            is_final_block = i == (len(num_blocks) - 1)
+            num_channels = block_out_channels[i]
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+            if not is_final_block:
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+
+            conv_out_channel = num_channels if not is_final_block else out_channels
+            layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x):
+        # Clamp.
+        x = torch.tanh(x / 3) * 3
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            x = self.layers(x)
+
+        return x
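EncoderTiny and DecoderTiny back the new AutoencoderTiny model (see diffusers/models/autoencoder_tiny.py in the file list), a distilled VAE intended as a fast drop-in for encoding and decoding latents. A sketch of swapping it into a Stable Diffusion pipeline; the checkpoint ids are examples, not implied by this diff:

import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Example tiny-VAE checkpoint id; substitute your own if needed.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a cheeseburger, studio lighting").images[0]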
@@ -90,6 +90,7 @@ else:
         StableDiffusionAttendAndExcitePipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
+        StableDiffusionGLIGENPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
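The newly exported StableDiffusionGLIGENPipeline (added in pipeline_stable_diffusion_gligen.py, +832 lines in the file list) drives the gated-attention path with grounding phrases and bounding boxes. A hedged usage sketch; the checkpoint id and box coordinates are illustrative only:

import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")
image = pipe(
    prompt="a birthday cake on a wooden table",
    gligen_phrases=["a birthday cake"],
    gligen_boxes=[[0.25, 0.45, 0.75, 0.95]],  # normalized xyxy box for the phrase above
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]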
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py CHANGED
@@ -334,7 +334,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             )
             prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
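This change lets prompt encoding pick a sensible dtype even when the pipeline is built without a text encoder (for example when precomputed prompt_embeds are passed in). The fallback order, restated as a standalone sketch:

def resolve_prompt_embeds_dtype(text_encoder, unet, prompt_embeds):
    # Prefer the text encoder's dtype, then the UNet's, and finally whatever
    # dtype the precomputed embeddings already carry.
    if text_encoder is not None:
        return text_encoder.dtype
    if unet is not None:
        return unet.dtype
    return prompt_embeds.dtype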
@@ -390,7 +397,7 @@
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -585,7 +592,7 @@
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             guidance_rescale (`float`, *optional*, defaults to 0.7):
                 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py CHANGED
@@ -335,7 +335,14 @@ class AltDiffusionImg2ImgPipeline(
             )
             prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -391,7 +398,7 @@
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -634,7 +641,7 @@
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

         Examples:

diffusers/pipelines/audioldm/pipeline_audioldm.py CHANGED
@@ -428,7 +428,7 @@ class AudioLDMPipeline(DiffusionPipeline):
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             output_type (`str`, *optional*, defaults to `"np"`):
                 The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or
                 `"pt"` to return a PyTorch `torch.Tensor` object.
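The docstring fixes above reflect that attention processors now live in models/attention_processor.py (models/cross_attention.py is dropped in this diff, see entry 112 in the file list). For context, a hedged example of what cross_attention_kwargs is typically used for; the LoRA path below is a placeholder, not a real checkpoint:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # hypothetical local LoRA weights
# The dict is forwarded verbatim to the attention processors; with LoRA loaded,
# "scale" controls the LoRA strength.
image = pipe(
    "a watercolor painting of a fox",
    cross_attention_kwargs={"scale": 0.5},
).images[0]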