diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
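Judging from the file list, the headline additions in 0.18.x are the Stable Diffusion XL, Kandinsky 2.2, Shap-E, LDM3D, paradigms (parallel sampling) and consistency-model pipelines, plus the parallel DDIM/DDPM and consistency-model schedulers. A minimal, hedged loading sketch for one of the new pipelines follows; the checkpoint id is a placeholder, and the new dummy_torch_and_transformers_and_invisible_watermark_objects.py entry suggests the SDXL pipelines additionally expect the invisible-watermark package to be installed.

```python
import torch
from diffusers import DiffusionPipeline

# Sketch only: substitute a real SDXL checkpoint id or local path for the placeholder.
pipe = DiffusionPipeline.from_pretrained(
    "path/or/repo-id-of-an-sdxl-checkpoint",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
```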
@@ -33,7 +33,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsampling layer before each final output
@@ -46,7 +46,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     add_downsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
@@ -70,8 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
 
             attn_block = FlaxTransformer2DModel(
                 in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
@@ -172,7 +172,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         add_upsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add upsampling layer before each final output
@@ -186,7 +186,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     prev_output_channel: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     add_upsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
@@ -211,8 +211,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
 
             attn_block = FlaxTransformer2DModel(
                 in_channels=self.out_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.out_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.out_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
@@ -317,7 +317,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of attention blocks layers
-        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
@@ -327,7 +327,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
-    attn_num_head_channels: int = 1
+    num_attention_heads: int = 1
     use_linear_projection: bool = False
     use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
@@ -348,8 +348,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         for _ in range(self.num_layers):
             attn_block = FlaxTransformer2DModel(
                 in_channels=self.in_channels,
-                n_heads=self.attn_num_head_channels,
-                d_head=self.in_channels // self.attn_num_head_channels,
+                n_heads=self.num_attention_heads,
+                d_head=self.in_channels // self.num_attention_heads,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 use_memory_efficient_attention=self.use_memory_efficient_attention,
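The hunks above are a plain rename in the Flax blocks: the field `attn_num_head_channels` becomes `num_attention_heads` and is forwarded to `FlaxTransformer2DModel` as before. Anyone constructing these blocks directly has to update the keyword; a minimal sketch (field values are illustrative, and flax/jax must be installed):

```python
import jax.numpy as jnp
from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

# From 0.18 on, the head count is configured via `num_attention_heads`.
block = FlaxCrossAttnDownBlock2D(
    in_channels=32,
    out_channels=32,
    num_layers=1,
    num_attention_heads=4,  # was: attn_num_head_channels=4
    dtype=jnp.float32,
)
```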
@@ -25,6 +25,9 @@ from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -50,27 +53,29 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
+    The output of [`UNet2DConditionModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.FloatTensor = None
 
 
 class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
-    and returns sample shaped output.
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
             Height and width of input/output sample.
-        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
-        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
@@ -78,9 +83,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
-            mid block layer if `None`.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
             Whether to include self-attention in the basic transformer blocks, see
@@ -92,50 +97,58 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`, it will skip the normalization and activation layers in post-processing
+            If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
-        encoder_hid_dim_type (`str`, *optional*, defaults to None):
-            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
-            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
-        class_embed_type (`str`, *optional*, defaults to None):
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
-        addition_embed_type (`str`, *optional*, defaults to None):
+        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
-        num_class_embeds (`int`, *optional*, defaults to None):
+        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
-        time_embedding_type (`str`, *optional*, default to `positional`):
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
-        time_embedding_dim (`int`, *optional*, default to `None`):
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
-        time_embedding_act_fn (`str`, *optional*, default to `None`):
-            Optional activation function to use on the time embeddings only one time before they as passed to the rest
-            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
-        timestep_post_act (`str, *optional*, default to `None`):
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
-        time_cond_proj_dim (`int`, *optional*, default to `None`):
-            The dimension of `cond_proj` layer in timestep embedding.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
         conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
         conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
-            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
         mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
-            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
-            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
-            default to `False`.
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+            otherwise.
     """
 
     _supports_gradient_checkpointing = True
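Two of the new config knobs documented above can be exercised directly. A minimal sketch with illustrative toy-sized values; note that, per the constructor change further down, passing `num_attention_heads` explicitly still raises a ValueError in this release, so the head count continues to come from `attention_head_dim`.

```python
from diffusers import UNet2DConditionModel

# Toy-sized UNet exercising the new per-block transformer depth; all sizes are illustrative.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(2, 4),            # still the way to set the head count in 0.18
    transformer_layers_per_block=(1, 2),  # 1 transformer block at the outer level, 2 at the inner one
)
print(sum(p.numel() for p in unet.parameters()))  # parameter count of the toy model
```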
@@ -166,13 +179,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -195,6 +211,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         self.sample_size = sample_size
 
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -211,6 +240,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
         if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
@@ -280,7 +314,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 image_embed_dim=cross_attention_dim,
                 cross_attention_dim=cross_attention_dim,
             )
-
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
                 f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
@@ -333,6 +372,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
 
@@ -353,6 +401,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if mid_block_only_cross_attention is None:
             mid_block_only_cross_attention = False
 
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -362,6 +413,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if isinstance(layers_per_block, int):
             layers_per_block = [layers_per_block] * len(down_block_types)
 
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -380,6 +434,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -388,7 +443,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim[i],
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -398,12 +453,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.down_blocks.append(down_block)
 
         # mid
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
                 in_channels=block_out_channels[-1],
                 temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
@@ -411,7 +468,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -425,7 +482,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
@@ -442,9 +499,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         only_cross_attention = list(reversed(only_cross_attention))
 
         output_channel = reversed_block_out_channels[0]
@@ -465,6 +523,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             up_block = get_up_block(
                 up_block_type,
                 num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -474,7 +533,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=reversed_cross_attention_dim[i],
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -483,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -530,11 +590,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `Attention` layers.
-                In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
 
         """
         count = len(self.attn_processors.keys())
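The reworked `set_attn_processor` docstring above is easiest to read next to a usage sketch (same illustrative toy-sized configuration as earlier): a single instance applies to every `Attention` layer, while a dict keyed by the paths reported by `unet.attn_processors` sets them individually.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(2, 4),
)

# One instance for all Attention layers ...
unet.set_attn_processor(AttnProcessor())

# ... or a dict keyed by the module paths from `attn_processors`.
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
```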
@@ -568,13 +632,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         r"""
         Enable sliced attention computation.
 
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
 
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
@@ -649,29 +713,31 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
+        The [`UNet2DConditionModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             encoder_attention_mask (`torch.Tensor`):
-                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
-                discard. Mask will be converted into a bias, which adds large negative values to attention scores
-                corresponding to "discard" tokens.
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
             cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            added_cond_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
-                embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
-                `addition_embed_type` for more information.
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
-                returning a tuple, the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
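The rewritten forward docstring maps onto a call like the following minimal sketch (illustrative toy configuration; no additional conditioning, so `added_cond_kwargs` is omitted):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(2, 4),
)

sample = torch.randn(1, 4, 16, 16)              # (batch, channel, height, width)
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, feature_dim)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)                         # torch.Size([1, 4, 16, 16])
```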
@@ -737,6 +803,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
 
         if self.class_embedding is not None:
             if class_labels is None:
@@ -758,9 +825,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
-            emb = emb + aug_emb
         elif self.config.addition_embed_type == "text_image":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
@@ -768,9 +834,44 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
             image_embs = added_cond_kwargs.get("image_embeds")
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
-
             aug_emb = self.add_embedding(text_embs, image_embs)
-            emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
 
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
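The `"text_time"` branch added above is the SDXL-style micro-conditioning path: pooled text embeddings and a set of time ids are projected, concatenated, and folded into the time embedding. A toy-sized sketch of the required `added_cond_kwargs` follows; all dimensions are illustrative, the only constraint being that `projection_class_embeddings_input_dim` equals the pooled-text dimension plus `addition_time_embed_dim` times the number of time ids.

```python
import torch
from diffusers import UNet2DConditionModel

# Toy UNet wired for "text_time" additional embeddings; sizes are illustrative,
# not the real SDXL configuration.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(2, 4),
    addition_embed_type="text_time",
    addition_time_embed_dim=8,
    # pooled text dim (16) + 6 time ids * 8 must match this value
    projection_class_embeddings_input_dim=16 + 6 * 8,
)

sample = torch.randn(1, 4, 16, 16)
encoder_hidden_states = torch.randn(1, 77, 32)
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 16),  # pooled text embedding
    # in SDXL these six ids encode original size, crop coordinates, and target size
    "time_ids": torch.tensor([[16, 16, 0, 0, 16, 16]], dtype=torch.float32),
}
out = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
)
print(out.sample.shape)  # torch.Size([1, 4, 16, 16])
```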
@@ -786,7 +887,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
-
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         # 2. pre-process
         sample = self.conv_in(sample)