diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +0 -1
  4. diffusers/dependency_versions_table.py +4 -5
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +8 -7
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
  171. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
  172. diffusers/loaders.py +0 -3336
  173. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
  175. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
  176. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+
15
16
  from dataclasses import dataclass
16
17
  from typing import Any, Dict, List, Optional, Tuple, Union
17
18
 
@@ -22,6 +23,7 @@ import torch.utils.checkpoint
22
23
  from ..configuration_utils import ConfigMixin, register_to_config
23
24
  from ..loaders import UNet2DConditionLoadersMixin
24
25
  from ..utils import BaseOutput, logging
26
+ from .activations import get_activation
25
27
  from .attention_processor import (
26
28
  ADDED_KV_ATTENTION_PROCESSORS,
27
29
  CROSS_ATTENTION_PROCESSORS,
@@ -98,14 +100,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
98
100
  sample_size: Optional[int] = None,
99
101
  in_channels: int = 4,
100
102
  out_channels: int = 4,
101
- down_block_types: Tuple[str] = (
103
+ down_block_types: Tuple[str, ...] = (
102
104
  "CrossAttnDownBlock3D",
103
105
  "CrossAttnDownBlock3D",
104
106
  "CrossAttnDownBlock3D",
105
107
  "DownBlock3D",
106
108
  ),
107
- up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
108
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
109
+ up_block_types: Tuple[str, ...] = (
110
+ "UpBlock3D",
111
+ "CrossAttnUpBlock3D",
112
+ "CrossAttnUpBlock3D",
113
+ "CrossAttnUpBlock3D",
114
+ ),
115
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
109
116
  layers_per_block: int = 2,
110
117
  downsample_padding: int = 1,
111
118
  mid_block_scale_factor: float = 1,
@@ -173,6 +180,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
173
180
  attention_head_dim=attention_head_dim,
174
181
  in_channels=block_out_channels[0],
175
182
  num_layers=1,
183
+ norm_num_groups=norm_num_groups,
176
184
  )
177
185
 
178
186
  # class embedding
@@ -265,7 +273,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
265
273
  self.conv_norm_out = nn.GroupNorm(
266
274
  num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
267
275
  )
268
- self.conv_act = nn.SiLU()
276
+ self.conv_act = get_activation("silu")
269
277
  else:
270
278
  self.conv_norm_out = None
271
279
  self.conv_act = None
@@ -301,7 +309,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
301
309
  return processors
302
310
 
303
311
  # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
304
- def set_attention_slice(self, slice_size):
312
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
305
313
  r"""
306
314
  Enable sliced attention computation.
307
315
 
@@ -403,7 +411,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
403
411
  for name, module in self.named_children():
404
412
  fn_recursive_attn_processor(name, module, processor)
405
413
 
406
- def enable_forward_chunking(self, chunk_size=None, dim=0):
414
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
407
415
  """
408
416
  Sets the attention processor to use [feed forward
409
417
  chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
@@ -459,7 +467,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
459
467
 
460
468
  self.set_attn_processor(processor, _remove_lora=True)
461
469
 
462
- def _set_gradient_checkpointing(self, module, value=False):
470
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
463
471
  if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
464
472
  module.gradient_checkpointing = value
465
473
 
@@ -509,7 +517,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
509
517
  down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
510
518
  mid_block_additional_residual: Optional[torch.Tensor] = None,
511
519
  return_dict: bool = True,
512
- ) -> Union[UNet3DConditionOutput, Tuple]:
520
+ ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
513
521
  r"""
514
522
  The [`UNet3DConditionModel`] forward method.
515
523
 
@@ -0,0 +1,589 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+
10
+ from ..configuration_utils import ConfigMixin, register_to_config
11
+ from ..utils import BaseOutput, logging
12
+ from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
13
+ from .embeddings import TimestepEmbedding
14
+ from .modeling_utils import ModelMixin
15
+
16
+
17
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class Kandinsky3UNetOutput(BaseOutput):
    """
    Output container returned by [`Kandinsky3UNet`] when `return_dict=True`.

    Args:
        sample (`torch.FloatTensor`):
            The tensor produced by the UNet's final `conv_out` convolution.
    """

    sample: torch.FloatTensor = None
23
+
24
+
25
# TODO(Yiyi): This helper function needs to be removed
def set_default_item(condition, item_1, item_2=None):
    """
    Pick between two values based on the truthiness of `condition`.

    Returns:
        `item_1` if `condition` is truthy, otherwise `item_2` (which defaults to `None`).
    """
    return item_1 if condition else item_2
31
+
32
+
33
# TODO(Yiyi): This helper function needs to be removed
def set_default_layer(
    condition, layer_1, args_1=None, kwargs_1=None, layer_2=torch.nn.Identity, args_2=None, kwargs_2=None
):
    """
    Instantiate one of two callables depending on `condition`.

    Args:
        condition: When truthy, `layer_1(*args_1, **kwargs_1)` is built; otherwise `layer_2(*args_2, **kwargs_2)`.
        layer_1: Callable (typically an `nn.Module` subclass) used when `condition` is truthy.
        args_1: Positional arguments for `layer_1`; `None` means no positional arguments.
        kwargs_1: Keyword arguments for `layer_1`; `None` means no keyword arguments.
        layer_2: Callable used when `condition` is falsy. Defaults to `torch.nn.Identity`.
        args_2: Positional arguments for `layer_2`; `None` means no positional arguments.
        kwargs_2: Keyword arguments for `layer_2`; `None` means no keyword arguments.

    Returns:
        The object returned by whichever callable was selected.
    """
    # `None` sentinels instead of `[]` / `{}` defaults, so no mutable default object is shared between calls.
    if condition:
        args = () if args_1 is None else args_1
        kwargs = {} if kwargs_1 is None else kwargs_1
        return layer_1(*args, **kwargs)
    args = () if args_2 is None else args_2
    kwargs = {} if kwargs_2 is None else kwargs_2
    return layer_2(*args, **kwargs)
39
+
40
+
41
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal embedding of step values.

    With `h = dim // 2` and frequencies `f_i = exp(-log(10000) * i / (h - 1))` for `i` in `[0, h)`, each input
    value `t` is mapped to `[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]`.
    The output width is therefore `2 * (dim // 2)`.

    Args:
        dim (`int`): Requested embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, type_tensor=None):
        """
        Args:
            x (`torch.Tensor`): Step values; indexed as `x[:, None]`, so expected to be 1-D (one value per row).
            type_tensor: Unused; kept so existing call sites keep working.

        Returns:
            `torch.Tensor` of shape `(x.shape[0], 2 * (dim // 2))`.
        """
        half_dim = self.dim // 2
        # With `half_dim == 1` the only frequency index is 0, so any non-zero denominator yields the same
        # (unit) frequency; the guard avoids the ZeroDivisionError `half_dim - 1 == 0` would otherwise raise.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)
53
+
54
+
55
class Kandinsky3EncoderProj(nn.Module):
    """
    Project encoder hidden states to the UNet's cross-attention width.

    Applies a bias-free `nn.Linear(encoder_hid_dim, cross_attention_dim)` followed by
    `nn.LayerNorm(cross_attention_dim)` over the last dimension.

    Args:
        encoder_hid_dim (`int`): Last-dimension size of the incoming hidden states.
        cross_attention_dim (`int`): Last-dimension size of the projected output.
    """

    def __init__(self, encoder_hid_dim, cross_attention_dim):
        super().__init__()
        self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False)
        self.projection_norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, x):
        return self.projection_norm(self.projection_linear(x))
65
+
66
+
67
class Kandinsky3UNet(ModelMixin, ConfigMixin):
    """
    UNet used by the Kandinsky 3 pipelines.

    Args:
        in_channels (`int`, defaults to 4):
            Channels of the input `sample`; the output uses the same number of channels.
        time_embedding_dim (`int`, defaults to 1536):
            Width of the timestep embedding passed to every block.
        groups (`int`, defaults to 32):
            Number of groups used by the group normalizations.
        attention_head_dim (`int`, defaults to 64):
            Width of each attention head.
        layers_per_block (`int` or `Tuple[int, ...]`, defaults to 3):
            ResNet groups per level. A single value is used for every level; a sequence must provide one value per
            entry of `block_out_channels` (the up path uses the same per-level values in reverse order).
        block_out_channels (`Tuple[int, ...]`, defaults to `(384, 768, 1536, 3072)`):
            Output channels of each level.
        cross_attention_dim (`int`, defaults to 4096):
            Width the encoder hidden states are projected to before cross-attention.
        encoder_hid_dim (`int`, defaults to 4096):
            Last-dimension size of the incoming `encoder_hidden_states`.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        time_embedding_dim: int = 1536,
        groups: int = 32,
        attention_head_dim: int = 64,
        layers_per_block: Union[int, Tuple[int, ...]] = 3,
        block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
        cross_attention_dim: int = 4096,
        encoder_hid_dim: int = 4096,
    ):
        super().__init__()

        # TODO(Yiyi): Give better name and put into config for the following 4 parameters
        expansion_ratio = 4
        compression_ratio = 2
        add_cross_attention = (False, True, True, True)
        add_self_attention = (False, True, True, True)

        out_channels = in_channels
        init_channels = block_out_channels[0] // 2
        # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
        # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
        self.time_proj = SinusoidalPosEmb(init_channels)

        self.time_embedding = TimestepEmbedding(
            init_channels,
            time_embedding_dim,
        )

        self.add_time_condition = Kandinsky3AttentionPooling(
            time_embedding_dim, cross_attention_dim, attention_head_dim
        )

        self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1)

        self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim)

        hidden_dims = [init_channels] + list(block_out_channels)
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]

        # Accept either one value for all levels or one value per level, as the annotation advertises.
        if isinstance(layers_per_block, (list, tuple)):
            num_blocks = list(layers_per_block)
            if len(num_blocks) != len(block_out_channels):
                raise ValueError(
                    f"`layers_per_block` must be an int or have one entry per `block_out_channels` entry, but got"
                    f" {len(num_blocks)} entries for {len(block_out_channels)} levels."
                )
        else:
            num_blocks = len(block_out_channels) * [layers_per_block]

        layer_params = [num_blocks, text_dims, add_self_attention]
        rev_layer_params = map(reversed, layer_params)

        cat_dims = []
        self.num_levels = len(in_out_dims)
        self.down_blocks = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
            zip(in_out_dims, *layer_params)
        ):
            down_sample = level != (self.num_levels - 1)
            # Every level except the deepest stores a skip connection that `forward` concatenates back in
            # before the matching up block, so that up block must expect `out_dim` extra channels.
            cat_dims.append(out_dim if down_sample else 0)
            self.down_blocks.append(
                Kandinsky3DownSampleBlock(
                    in_dim,
                    out_dim,
                    time_embedding_dim,
                    text_dim,
                    res_block_num,
                    groups,
                    attention_head_dim,
                    expansion_ratio,
                    compression_ratio,
                    down_sample,
                    self_attention,
                )
            )

        self.up_blocks = nn.ModuleList([])
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(
            zip(reversed(in_out_dims), *rev_layer_params)
        ):
            up_sample = level != 0
            self.up_blocks.append(
                Kandinsky3UpSampleBlock(
                    in_dim,
                    cat_dims.pop(),
                    out_dim,
                    time_embedding_dim,
                    text_dim,
                    res_block_num,
                    groups,
                    attention_head_dim,
                    expansion_ratio,
                    compression_ratio,
                    up_sample,
                    self_attention,
                )
            )

        self.conv_norm_out = nn.GroupNorm(groups, init_channels)
        self.conv_act_out = nn.SiLU()
        self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name (`"<module path>.processor"`).
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Note: consumes entries from the caller's dict.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation (`Kandi3AttnProcessor`).
        """
        self.set_attn_processor(Kandi3AttnProcessor())

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle the flag on any submodule that exposes a `gradient_checkpointing` attribute.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
        """
        Args:
            sample (`torch.FloatTensor`):
                Input of shape `(batch, in_channels, height, width)` (consumed by the `Conv2d` `conv_in`).
            timestep (`torch.Tensor`, `float` or `int`):
                Timestep(s); a Python number or 0-d tensor is broadcast over the batch.
            encoder_hidden_states (`torch.FloatTensor`, *optional*):
                Hidden states whose last dimension is `encoder_hid_dim`. NOTE(review): the cross-attention levels
                of the default configuration use them unconditionally, so `None` is presumably only usable without
                cross-attention — confirm.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Mask for `encoder_hidden_states`, forwarded to the attention pooling and every block.
            return_dict (`bool`, defaults to `True`):
                Return a [`Kandinsky3UNetOutput`] instead of a plain tuple.

        Returns:
            [`Kandinsky3UNetOutput`] if `return_dict` is `True`, otherwise a 1-tuple with the output tensor.
        """
        # TODO(Yiyi): Clean up the following variables - these names should not be used
        # but instead only the ones that we pass to forward
        x = sample
        context_mask = encoder_attention_mask
        context = encoder_hidden_states

        if not torch.is_tensor(timestep):
            dtype = torch.float32 if isinstance(timestep, float) else torch.int32
            timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
        elif len(timestep.shape) == 0:
            timestep = timestep[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = timestep.expand(sample.shape[0])
        time_embed_input = self.time_proj(timestep).to(x.dtype)
        time_embed = self.time_embedding(time_embed_input)

        # Only project the encoder states when they are provided; projecting first would fail on `None`
        # before the guard could ever be evaluated.
        if context is not None:
            context = self.encoder_hid_proj(context)
            time_embed = self.add_time_condition(time_embed, context, context_mask)

        hidden_states = []
        x = self.conv_in(x)
        for level, down_sample in enumerate(self.down_blocks):
            x = down_sample(x, time_embed, context, context_mask)
            # Keep skip connections for every level except the deepest one.
            if level != self.num_levels - 1:
                hidden_states.append(x)

        for level, up_sample in enumerate(self.up_blocks):
            if level != 0:
                x = torch.cat([x, hidden_states.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)

        x = self.conv_norm_out(x)
        x = self.conv_act_out(x)
        x = self.conv_out(x)

        if not return_dict:
            return (x,)
        return Kandinsky3UNetOutput(sample=x)
274
+
275
+
276
class Kandinsky3UpSampleBlock(nn.Module):
    """
    One decoder level of [`Kandinsky3UNet`].

    The block expects `in_channels + cat_dim` input channels. It runs `num_blocks` groups of
    (ResNet -> optional cross-attention -> ResNet) and then an optional self-attention on the `out_channels` result.
    Only the first group's input ResNet receives a non-trivial `up_resolution` (derived from `up_sample`);
    presumably that triggers the upsampling inside `Kandinsky3ResNetBlock` — confirm there.

    Args:
        in_channels (`int`): Working width of the block (excluding the concatenated skip connection).
        cat_dim (`int`): Extra input channels coming from the concatenated skip connection (0 if none).
        out_channels (`int`): Output channels of the last group and of the trailing self-attention.
        time_embed_dim (`int`): Width of the time embedding passed to every sub-layer.
        context_dim (`int`, *optional*): Cross-attention context width; `None` disables cross-attention.
        num_blocks (`int`, defaults to 3): Number of ResNet groups.
        groups, head_dim, expansion_ratio, compression_ratio: Forwarded to the ResNet / attention sub-blocks.
        up_sample (`bool`, defaults to `True`): Whether the first group's input ResNet gets the up flag.
        self_attention (`bool`, defaults to `True`): Whether to apply the trailing self-attention.
    """

    def __init__(
        self,
        in_channels,
        cat_dim,
        out_channels,
        time_embed_dim,
        context_dim=None,
        num_blocks=3,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        up_sample=True,
        self_attention=True,
    ):
        super().__init__()
        up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
        # (input, output) channels per group: the first group consumes the concatenated skip connection and the
        # last one produces `out_channels`. Building the input and output lists separately keeps the
        # `num_blocks == 1` case correct: its single group must map `in_channels + cat_dim` to `out_channels`.
        group_in_channels = [in_channels + cat_dim] + [in_channels] * (num_blocks - 1)
        group_out_channels = [in_channels] * (num_blocks - 1) + [out_channels]
        hidden_channels = list(zip(group_in_channels, group_out_channels))
        attentions = []
        resnets_in = []
        resnets_out = []

        self.self_attention = self_attention
        self.context_dim = context_dim

        # attentions[0] is the trailing self-attention (or Identity); the rest are per-group cross-attentions.
        attentions.append(
            set_default_layer(
                self_attention,
                Kandinsky3AttentionBlock,
                (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
                layer_2=nn.Identity,
            )
        )

        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            resnets_in.append(
                Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
            )
            attentions.append(
                set_default_layer(
                    context_dim is not None,
                    Kandinsky3AttentionBlock,
                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=nn.Identity,
                )
            )
            resnets_out.append(
                Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets_in = nn.ModuleList(resnets_in)
        self.resnets_out = nn.ModuleList(resnets_out)

    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        """
        Run the ResNet / cross-attention groups, then the optional self-attention.

        Args:
            x (`torch.Tensor`): Input with `in_channels + cat_dim` channels.
            time_embed (`torch.Tensor`): Time embedding passed to every sub-layer.
            context, context_mask: Cross-attention inputs; only used when `context_dim` is not `None`.
            image_mask: Forwarded unchanged to the attention blocks.
        """
        for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
            x = resnet_in(x, time_embed)
            if self.context_dim is not None:
                x = attention(x, time_embed, context, context_mask, image_mask)
            x = resnet_out(x, time_embed)

        if self.self_attention:
            x = self.attentions[0](x, time_embed, image_mask=image_mask)
        return x
345
+
346
+
347
class Kandinsky3DownSampleBlock(nn.Module):
    """
    One encoder level of [`Kandinsky3UNet`].

    Applies an optional self-attention to the input, then `num_blocks` groups of
    (ResNet -> optional cross-attention -> ResNet). Only the last group's output ResNet receives a resolution
    flag (`False` when `down_sample` is truthy, `None` otherwise); presumably that triggers the downsampling inside
    `Kandinsky3ResNetBlock` — confirm there.

    Args:
        in_channels (`int`): Input channels of the block.
        out_channels (`int`): Output channels of every group.
        time_embed_dim (`int`): Width of the time embedding passed to every sub-layer.
        context_dim (`int`, *optional*): Cross-attention context width; `None` disables cross-attention.
        num_blocks (`int`, defaults to 3): Number of ResNet groups.
        groups, head_dim, expansion_ratio, compression_ratio: Forwarded to the ResNet / attention sub-blocks.
        down_sample (`bool`, defaults to `True`): Whether the last group's output ResNet gets the down flag.
        self_attention (`bool`, defaults to `True`): Whether to apply the leading self-attention.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        time_embed_dim,
        context_dim=None,
        num_blocks=3,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        down_sample=True,
        self_attention=True,
    ):
        super().__init__()
        self.self_attention = self_attention
        self.context_dim = context_dim

        # Slot 0 holds the leading self-attention (or Identity); slots 1.. hold the per-group cross-attentions.
        attention_layers = [
            Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
            if self_attention
            else nn.Identity()
        ]
        input_resnets = []
        output_resnets = []

        last_resolution = [None, None, False if down_sample else None, None]
        resolutions = [[None] * 4] * (num_blocks - 1) + [last_resolution]
        channel_pairs = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)

        for (group_in, group_out), resolution in zip(channel_pairs, resolutions):
            input_resnets.append(
                Kandinsky3ResNetBlock(group_in, group_out, time_embed_dim, groups, compression_ratio)
            )
            attention_layers.append(
                Kandinsky3AttentionBlock(group_out, time_embed_dim, context_dim, groups, head_dim, expansion_ratio)
                if context_dim is not None
                else nn.Identity()
            )
            output_resnets.append(
                Kandinsky3ResNetBlock(group_out, group_out, time_embed_dim, groups, compression_ratio, resolution)
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets_in = nn.ModuleList(input_resnets)
        self.resnets_out = nn.ModuleList(output_resnets)

    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        """
        Run the optional self-attention, then the ResNet / cross-attention groups.

        Args:
            x (`torch.Tensor`): Input with `in_channels` channels.
            time_embed (`torch.Tensor`): Time embedding passed to every sub-layer.
            context, context_mask: Cross-attention inputs; only used when `context_dim` is not `None`.
            image_mask: Forwarded unchanged to the attention blocks.
        """
        if self.self_attention:
            x = self.attentions[0](x, time_embed, image_mask=image_mask)

        for slot, (resnet_in, resnet_out) in enumerate(zip(self.resnets_in, self.resnets_out), start=1):
            x = resnet_in(x, time_embed)
            if self.context_dim is not None:
                x = self.attentions[slot](x, time_embed, context, context_mask, image_mask)
            x = resnet_out(x, time_embed)
        return x
413
+
414
+
415
class Kandinsky3ConditionalGroupNorm(nn.Module):
    """
    Group normalization whose scale and shift are predicted from a conditioning vector.

    `x` is normalized by a non-affine `nn.GroupNorm`. `context` is passed through `SiLU -> Linear` to
    `2 * normalized_shape` values, split along dim 1 into `scale` and `shift`, and applied as
    `norm(x) * (scale + 1) + shift`. The linear layer starts at zero, so a freshly built module is a plain
    group normalization.

    Args:
        groups (`int`): Number of normalization groups.
        normalized_shape (`int`): Number of channels of `x`.
        context_dim (`int`): Last-dimension size of `context`.
    """

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
        projection = self.context_mlp[1]
        projection.weight.data.zero_()
        projection.bias.data.zero_()

    def forward(self, x, context):
        modulation = self.context_mlp(context)
        # Append one trailing singleton axis per dimension of `x` beyond the first two, so the modulation
        # broadcasts over them.
        trailing_axes = x.dim() - 2
        modulation = modulation[(...,) + (None,) * trailing_axes]
        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.0) + shift
432
+
433
+
434
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
# sure we can delete it and instead just pass an attention_mask
class Attention(nn.Module):
    """
    Multi-head attention layer that delegates the computation to a swappable processor.

    Holds bias-free `to_q` (from `in_channels`), `to_k` / `to_v` (from `context_dim`) and `to_out` projections,
    all producing `out_channels` features. It uses `Kandi3AttnProcessor` by default.

    Args:
        in_channels (`int`): Last-dimension size of the query input.
        out_channels (`int`): Total attention width; must be divisible by `head_dim`.
        context_dim (`int`): Last-dimension size of the key/value input.
        head_dim (`int`, defaults to 64): Width of each attention head.

    Raises:
        ValueError: If `out_channels` is not divisible by `head_dim`.
    """

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped when Python runs with `-O`.
        if out_channels % head_dim != 0:
            raise ValueError(f"`out_channels` ({out_channels}) must be divisible by `head_dim` ({head_dim}).")
        self.num_heads = out_channels // head_dim
        self.scale = head_dim**-0.5

        # to_q
        self.to_q = nn.Linear(in_channels, out_channels, bias=False)
        # to_k
        self.to_k = nn.Linear(context_dim, out_channels, bias=False)
        # to_v
        self.to_v = nn.Linear(context_dim, out_channels, bias=False)
        processor = Kandi3AttnProcessor()
        self.set_processor(processor)
        # to_out
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))

    def set_processor(self, processor: "AttnProcessor"):  # noqa: F821
        """
        Replace the attention processor.

        If the current processor is an `nn.Module` (and therefore registered in `self._modules`) but the new one is
        not, the old one is removed from `self._modules` and an info message is logged.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, x, context, context_mask=None, image_mask=None):
        """
        Run the processor as `processor(self, x, context=context, context_mask=context_mask)`.

        NOTE(review): `image_mask` is accepted but not forwarded to the processor.
        """
        return self.processor(
            self,
            x,
            context=context,
            context_mask=context_mask,
        )
475
+
476
+
477
class Kandinsky3Block(nn.Module):
    """Time-conditioned group norm -> SiLU -> optional upsample -> conv projection -> optional downsample.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the projection convolution.
        time_embed_dim: size of the embedding that conditions the group norm.
        kernel_size: projection kernel size; padding is 1 when ``kernel_size > 1`` and 0 otherwise.
        norm_groups: number of groups in the conditional group norm.
        up_resolution: truthy -> 2x ``ConvTranspose2d`` (kernel 2, stride 2) before the projection;
            falsy but not ``None`` -> stride-2 ``Conv2d`` (kernel 2) after it; ``None`` -> neither.
            NOTE(review): relies on ``set_default_layer`` (defined elsewhere) yielding a pass-through
            layer when its flag is falsy, since ``forward`` calls both resamplers unconditionally.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        resolution_given = up_resolution is not None
        upsample_flag = resolution_given and up_resolution
        downsample_flag = resolution_given and not up_resolution

        self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()
        self.up_sample = set_default_layer(
            upsample_flag,
            nn.ConvTranspose2d,
            (in_channels, in_channels),
            {"kernel_size": 2, "stride": 2},
        )
        self.projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=1 if kernel_size > 1 else 0,
        )
        self.down_sample = set_default_layer(
            downsample_flag,
            nn.Conv2d,
            (out_channels, out_channels),
            {"kernel_size": 2, "stride": 2},
        )

    def forward(self, x, time_embed):
        hidden = self.activation(self.group_norm(x, time_embed))
        for layer in (self.up_sample, self.projection, self.down_sample):
            hidden = layer(hidden)
        return hidden
504
+
505
+
506
class Kandinsky3ResNetBlock(nn.Module):
    """Bottleneck residual block built from four time-conditioned ``Kandinsky3Block``s.

    Main path: ``in_channels -> hidden -> hidden -> hidden -> out_channels`` with kernel sizes
    ``(1, 3, 3, 1)``, where ``hidden = max(in_channels, out_channels) // compression_ratio``.
    Shortcut path: an upsample if any entry of ``up_resolutions`` is ``True``, a 1x1 projection if the
    channel counts differ, and a stride-2 downsample if any entry is ``False``. The output is
    ``shortcut(x) + main(x)``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_embed_dim: size of the conditioning embedding passed to every sub-block.
        norm_groups: number of groups for the conditional group norms.
        compression_ratio: divisor applied to the widest channel count to get the bottleneck width.
        up_resolutions: per-sub-block resolution flags (``True``/``False``/``None``); any sequence is
            accepted. ``zip`` stops at the shortest input, so only the first four entries are used, and
            fewer sub-blocks are built if fewer entries are given.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        time_embed_dim,
        norm_groups=32,
        compression_ratio=2,
        # Immutable default (was `4 * [None]`) to avoid the shared mutable-default pitfall; the value is
        # only iterated and membership-tested, so behaviour is unchanged.
        up_resolutions=(None, None, None, None),
    ):
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        hidden_channels = (
            [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
        )
        self.resnet_blocks = nn.ModuleList(
            [
                Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
                for (in_channel, out_channel), kernel_size, up_resolution in zip(
                    hidden_channels, kernel_sizes, up_resolutions
                )
            ]
        )
        # The shortcut mirrors the main path's resolution/channel changes so `x + out` lines up.
        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
        )

    def forward(self, x, time_embed):
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        x = x + out
        return x
544
+
545
+
546
class Kandinsky3AttentionPooling(nn.Module):
    """Pool a context sequence with attention and add the pooled vector to ``x``.

    The query is the mean of ``context`` over dim 1. It attends over ``context`` itself, and the
    attention output (squeezed at dim 1) is added residually to ``x``.
    NOTE(review): assumes context is (batch, seq_len, context_dim) and x is (batch, num_channels).
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        # Query and key/value both come from the context; the output is projected to num_channels.
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        mean_query = context.mean(dim=1, keepdim=True)
        pooled = self.attention(mean_query, context, context_mask)
        pooled = pooled.squeeze(1)
        return x + pooled
554
+
555
+
556
class Kandinsky3AttentionBlock(nn.Module):
    """Pre-norm attention residual followed by a pre-norm 1x1-conv feed-forward residual.

    Both norms are ``Kandinsky3ConditionalGroupNorm``s conditioned on ``time_embed``. The spatial grid
    is flattened into ``height * width`` tokens. When ``context`` is ``None`` the tokens attend to
    themselves.

    Args:
        num_channels: channels of the input feature map (and of the attention/feed-forward output).
        time_embed_dim: size of the conditioning embedding for both norms.
        context_dim: feature size of the cross-attention context; falls back to ``num_channels``.
        norm_groups: number of groups in the conditional norms.
        head_dim: attention head size.
        expansion_ratio: hidden width multiplier of the feed-forward (``expansion_ratio * num_channels``).
    """

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # Without an explicit context dim the block self-attends, so keys/values are num_channels wide.
        kv_dim = context_dim or num_channels
        self.attention = Attention(num_channels, num_channels, kv_dim, head_dim)

        self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        inner_channels = expansion_ratio * num_channels
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, inner_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(inner_channels, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        height, width = x.shape[-2:]
        normed = self.in_norm(x, time_embed)
        # (batch, channels, height, width) -> (batch, height * width, channels) token sequence.
        tokens = normed.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
        if context is None:
            context = tokens

        if image_mask is not None:
            # Max-pool the mask down to this feature map's resolution, then flatten it per sample.
            # NOTE(review): the `Attention` class in this module does not forward `image_mask` to its
            # processor, so this pooled mask is currently unused downstream.
            mask_height, mask_width = image_mask.shape[-2:]
            pool = (mask_height // height, mask_width // width)
            pooled_mask = F.max_pool2d(image_mask, pool, pool)
            image_mask = pooled_mask.reshape(pooled_mask.shape[0], -1)

        attended = self.attention(tokens, context, context_mask, image_mask)
        # Tokens back to a (batch, channels, height, width) map for the residual add.
        x = x + attended.permute(0, 2, 1).reshape(attended.shape[0], -1, height, width)

        return x + self.feed_forward(self.out_norm(x, time_embed))