diffusers-0.24.0-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl

Files changed (174)
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
--- a/diffusers/models/transformer_2d.py
+++ b/diffusers/models/transformer_2d.py
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
 
         self.gradient_checkpointing = False
 
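Note: `CaptionProjection` is renamed to `PixArtAlphaTextProjection` in 0.25.0, so code importing the old name from `diffusers.models.embeddings` breaks. A minimal, hypothetical compatibility shim (the aliasing is our suggestion, not part of the library):

```python
# Hypothetical shim: keep old call sites working across the 0.24 -> 0.25 rename.
try:
    # diffusers >= 0.25.0: the class only changed names, not behavior.
    from diffusers.models.embeddings import PixArtAlphaTextProjection as CaptionProjection
except ImportError:
    # diffusers <= 0.24.0 still ships the old name.
    from diffusers.models.embeddings import CaptionProjection
```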
--- a/diffusers/models/unet_2d_condition.py
+++ b/diffusers/models/unet_2d_condition.py
@@ -25,6 +25,7 @@ from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
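The new methods can be toggled around inference. A minimal usage sketch, assuming a CUDA GPU and a standard SD checkpoint (the model id and prompt are illustrative, not prescribed by this diff):

```python
# Sketch: enable the experimental fused-QKV path on a loaded pipeline's UNet.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.unet.fuse_qkv_projections()    # fuses q/k/v for self-attn, k/v for cross-attn
image = pipe("an astronaut riding a horse").images[0]
pipe.unet.unfuse_qkv_projections()  # restores the saved attention processors
```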
--- a/diffusers/models/unet_kandi3.py
+++ b/diffusers/models/unet_kandinsky3.py
@@ -1,16 +1,28 @@
-import math
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Dict, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
-from .embeddings import TimestepEmbedding
+from .attention_processor import Attention, AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 
 
@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput):
     sample: torch.FloatTensor = None
 
 
-# TODO(Yiyi): This class needs to be removed
-def set_default_item(condition, item_1, item_2=None):
-    if condition:
-        return item_1
-    else:
-        return item_2
-
-
-# TODO(Yiyi): This class needs to be removed
-def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
-    if condition:
-        return layer_1(*args_1, **kwargs_1)
-    else:
-        return layer_2(*args_2, **kwargs_2)
-
-
-# TODO(Yiyi): This class should be removed and be replaced by Timesteps
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x, type_tensor=None):
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        return torch.cat((emb.sin(), emb.cos()), dim=-1)
-
-
 class Kandinsky3EncoderProj(nn.Module):
     def __init__(self, encoder_hid_dim, cross_attention_dim):
         super().__init__()
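The deleted `set_default_item`/`set_default_layer` helpers were thin wrappers over conditional expressions; the hunks below simply inline them. A small sketch of the mapping (the helper call is quoted from the removed 0.24 code; `cond` is a stand-in):

```python
import torch.nn as nn

cond = True  # stand-in for conditions like `self_attention` or `context_dim is not None`

# 0.24, via the removed helper:
#   layer = set_default_layer(cond, nn.Conv2d, (4, 8), {"kernel_size": 1})
# 0.25, inlined as in the hunks below:
layer = nn.Conv2d(4, 8, kernel_size=1) if cond else nn.Identity()
```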
@@ -87,9 +69,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
 
         out_channels = in_channels
         init_channels = block_out_channels[0] // 2
-        # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
-        # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
-        self.time_proj = SinusoidalPosEmb(init_channels)
+        self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
 
         self.time_embedding = TimestepEmbedding(
             init_channels,
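This resolves the old TODO: with `flip_sin_to_cos=False` and `downscale_freq_shift=1`, `Timesteps` computes the same frequencies and the same sin-then-cos ordering as the deleted `SinusoidalPosEmb`. A quick numerical sketch checking the equivalence (the embedding is reimplemented from the hunk above; the dimension is illustrative):

```python
import math

import torch
from diffusers.models.embeddings import Timesteps

def sinusoidal_pos_emb(x, dim):
    # Reimplementation of the SinusoidalPosEmb.forward removed above.
    half_dim = dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
    emb = x[:, None] * emb[None, :]
    return torch.cat((emb.sin(), emb.cos()), dim=-1)

dim = 192  # e.g. block_out_channels[0] // 2 for the default Kandinsky 3 config
t = torch.arange(0, 1000, 50, dtype=torch.float32)
reference = sinusoidal_pos_emb(t, dim)
replacement = Timesteps(dim, flip_sin_to_cos=False, downscale_freq_shift=1)(t)
# Loose tolerance: the two versions associate the float ops slightly differently.
assert torch.allclose(reference, replacement, atol=1e-3)
```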
@@ -106,7 +86,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
 
         hidden_dims = [init_channels] + list(block_out_channels)
         in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
-        text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
+        text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
         num_blocks = len(block_out_channels) * [layers_per_block]
         layer_params = [num_blocks, text_dims, add_self_attention]
         rev_layer_params = map(reversed, layer_params)
@@ -118,7 +98,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
             zip(in_out_dims, *layer_params)
         ):
             down_sample = level != (self.num_levels - 1)
-            cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
+            cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
             self.down_blocks.append(
                 Kandinsky3DownSampleBlock(
                     in_dim,
@@ -223,18 +203,16 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
         """
         Disables custom attention processors and sets the default attention implementation.
         """
-        self.set_attn_processor(Kandi3AttnProcessor())
+        self.set_attn_processor(AttnProcessor())
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
-        # TODO(Yiyi): Clean up the following variables - these names should not be used
-        # but instead only the ones that we pass to forward
-        x = sample
-        context_mask = encoder_attention_mask
-        context = encoder_hidden_states
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
         if not torch.is_tensor(timestep):
             dtype = torch.float32 if isinstance(timestep, float) else torch.int32
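The rewritten forward folds the text mask in up front: a binary `encoder_attention_mask` (1 = attend, 0 = ignore) becomes an additive bias that the standard `AttnProcessor` adds to the attention logits, which is what allows the custom `Kandi3AttnProcessor` to be dropped. A standalone sketch of the conversion:

```python
import torch

mask = torch.tensor([[1, 1, 1, 0, 0]])  # binary keep/drop mask, shape (batch, seq)
bias = (1 - mask.to(torch.float32)) * -10000.0  # 0.0 where kept, -10000.0 where masked
bias = bias.unsqueeze(1)  # shape (batch, 1, seq): broadcasts over query positions
# After softmax, positions carrying a -10000.0 bias receive effectively zero attention.
```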
@@ -244,33 +222,33 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = timestep.expand(sample.shape[0])
-        time_embed_input = self.time_proj(timestep).to(x.dtype)
+        time_embed_input = self.time_proj(timestep).to(sample.dtype)
         time_embed = self.time_embedding(time_embed_input)
 
-        context = self.encoder_hid_proj(context)
+        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
 
-        if context is not None:
-            time_embed = self.add_time_condition(time_embed, context, context_mask)
+        if encoder_hidden_states is not None:
+            time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
 
         hidden_states = []
-        x = self.conv_in(x)
+        sample = self.conv_in(sample)
         for level, down_sample in enumerate(self.down_blocks):
-            x = down_sample(x, time_embed, context, context_mask)
+            sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
             if level != self.num_levels - 1:
-                hidden_states.append(x)
+                hidden_states.append(sample)
 
         for level, up_sample in enumerate(self.up_blocks):
             if level != 0:
-                x = torch.cat([x, hidden_states.pop()], dim=1)
-            x = up_sample(x, time_embed, context, context_mask)
+                sample = torch.cat([sample, hidden_states.pop()], dim=1)
+            sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
 
-        x = self.conv_norm_out(x)
-        x = self.conv_act_out(x)
-        x = self.conv_out(x)
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act_out(sample)
+        sample = self.conv_out(sample)
 
         if not return_dict:
-            return (x,)
-        return Kandinsky3UNetOutput(sample=x)
+            return (sample,)
+        return Kandinsky3UNetOutput(sample=sample)
 
 
 class Kandinsky3UpSampleBlock(nn.Module):
@@ -290,7 +268,7 @@ class Kandinsky3UpSampleBlock(nn.Module):
         self_attention=True,
     ):
         super().__init__()
-        up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
+        up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
         hidden_channels = (
             [(in_channels + cat_dim, in_channels)]
             + [(in_channels, in_channels)] * (num_blocks - 2)
@@ -303,27 +281,27 @@ class Kandinsky3UpSampleBlock(nn.Module):
         self.self_attention = self_attention
         self.context_dim = context_dim
 
-        attentions.append(
-            set_default_layer(
-                self_attention,
-                Kandinsky3AttentionBlock,
-                (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
-                layer_2=nn.Identity,
+        if self_attention:
+            attentions.append(
+                Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
             )
-        )
+        else:
+            attentions.append(nn.Identity())
 
         for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
             resnets_in.append(
                 Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
             )
-            attentions.append(
-                set_default_layer(
-                    context_dim is not None,
-                    Kandinsky3AttentionBlock,
-                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
-                    layer_2=nn.Identity,
+
+            if context_dim is not None:
+                attentions.append(
+                    Kandinsky3AttentionBlock(
+                        in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+                    )
                 )
-            )
+            else:
+                attentions.append(nn.Identity())
+
             resnets_out.append(
                 Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
             )
@@ -367,29 +345,29 @@ class Kandinsky3DownSampleBlock(nn.Module):
         self.self_attention = self_attention
         self.context_dim = context_dim
 
-        attentions.append(
-            set_default_layer(
-                self_attention,
-                Kandinsky3AttentionBlock,
-                (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
-                layer_2=nn.Identity,
+        if self_attention:
+            attentions.append(
+                Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
             )
-        )
+        else:
+            attentions.append(nn.Identity())
 
-        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
+        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
         hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
         for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
             resnets_in.append(
                 Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
             )
-            attentions.append(
-                set_default_layer(
-                    context_dim is not None,
-                    Kandinsky3AttentionBlock,
-                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
-                    layer_2=nn.Identity,
+
+            if context_dim is not None:
+                attentions.append(
+                    Kandinsky3AttentionBlock(
+                        out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
+                    )
                 )
-            )
+            else:
+                attentions.append(nn.Identity())
+
             resnets_out.append(
                 Kandinsky3ResNetBlock(
                     out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
@@ -431,68 +409,23 @@ class Kandinsky3ConditionalGroupNorm(nn.Module):
         return x
 
 
-# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
-# sure we can delete it and instead just pass an attention_mask
-class Attention(nn.Module):
-    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
-        super().__init__()
-        assert out_channels % head_dim == 0
-        self.num_heads = out_channels // head_dim
-        self.scale = head_dim**-0.5
-
-        # to_q
-        self.to_q = nn.Linear(in_channels, out_channels, bias=False)
-        # to_k
-        self.to_k = nn.Linear(context_dim, out_channels, bias=False)
-        # to_v
-        self.to_v = nn.Linear(context_dim, out_channels, bias=False)
-        processor = Kandi3AttnProcessor()
-        self.set_processor(processor)
-        # to_out
-        self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
-
-    def set_processor(self, processor: "AttnProcessor"):  # noqa: F821
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def forward(self, x, context, context_mask=None, image_mask=None):
-        return self.processor(
-            self,
-            x,
-            context=context,
-            context_mask=context_mask,
-        )
-
-
 class Kandinsky3Block(nn.Module):
     def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
         super().__init__()
         self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
         self.activation = nn.SiLU()
-        self.up_sample = set_default_layer(
-            up_resolution is not None and up_resolution,
-            nn.ConvTranspose2d,
-            (in_channels, in_channels),
-            {"kernel_size": 2, "stride": 2},
-        )
+        if up_resolution is not None and up_resolution:
+            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+        else:
+            self.up_sample = nn.Identity()
+
         padding = int(kernel_size > 1)
         self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
-        self.down_sample = set_default_layer(
-            up_resolution is not None and not up_resolution,
-            nn.Conv2d,
-            (out_channels, out_channels),
-            {"kernel_size": 2, "stride": 2},
-        )
+
+        if up_resolution is not None and not up_resolution:
+            self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+        else:
+            self.down_sample = nn.Identity()
 
     def forward(self, x, time_embed):
         x = self.group_norm(x, time_embed)
@@ -521,14 +454,18 @@ class Kandinsky3ResNetBlock(nn.Module):
                 )
             ]
         )
-        self.shortcut_up_sample = set_default_layer(
-            True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
+        self.shortcut_up_sample = (
+            nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
+            if True in up_resolutions
+            else nn.Identity()
         )
-        self.shortcut_projection = set_default_layer(
-            in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
+        self.shortcut_projection = (
+            nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
         )
-        self.shortcut_down_sample = set_default_layer(
-            False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
+        self.shortcut_down_sample = (
+            nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
+            if False in up_resolutions
+            else nn.Identity()
         )
 
     def forward(self, x, time_embed):
@@ -546,9 +483,16 @@ class Kandinsky3ResNetBlock(nn.Module):
 class Kandinsky3AttentionPooling(nn.Module):
     def __init__(self, num_channels, context_dim, head_dim=64):
         super().__init__()
-        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
+        self.attention = Attention(
+            context_dim,
+            context_dim,
+            dim_head=head_dim,
+            out_dim=num_channels,
+            out_bias=False,
+        )
 
     def forward(self, x, context, context_mask=None):
+        context_mask = context_mask.to(dtype=context.dtype)
         context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
         return x + context.squeeze(1)
 
@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module):
     def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
         super().__init__()
         self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
-        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
+        self.attention = Attention(
+            num_channels,
+            context_dim or num_channels,
+            dim_head=head_dim,
+            out_dim=num_channels,
+            out_bias=False,
+        )
 
         hidden_channels = expansion_ratio * num_channels
         self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
@@ -572,14 +522,10 @@ class Kandinsky3AttentionBlock(nn.Module):
         out = self.in_norm(x, time_embed)
         out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
         context = context if context is not None else out
+        if context_mask is not None:
+            context_mask = context_mask.to(dtype=context.dtype)
 
-        if image_mask is not None:
-            mask_height, mask_width = image_mask.shape[-2:]
-            kernel_size = (mask_height // height, mask_width // width)
-            image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
-            image_mask = image_mask.reshape(image_mask.shape[0], -1)
-
-        out = self.attention(out, context, context_mask, image_mask)
+        out = self.attention(out, context, context_mask)
         out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
         x = x + out