diffusers-0.17.1-py3-none-any.whl → diffusers-0.18.2-py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -18,6 +18,9 @@ from ...models.attention_processor import (
  from ...models.dual_transformer_2d import DualTransformer2DModel
  from ...models.embeddings import (
  GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
  TextImageProjection,
  TextImageTimeEmbedding,
  TextTimeEmbedding,
@@ -41,7 +44,7 @@ def get_down_block(
  add_downsample,
  resnet_eps,
  resnet_act_fn,
- attn_num_head_channels,
+ num_attention_heads,
  resnet_groups=None,
  cross_attention_dim=None,
  downsample_padding=None,
@@ -82,7 +85,7 @@ def get_down_block(
  resnet_groups=resnet_groups,
  downsample_padding=downsample_padding,
  cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
+ num_attention_heads=num_attention_heads,
  dual_cross_attention=dual_cross_attention,
  use_linear_projection=use_linear_projection,
  only_cross_attention=only_cross_attention,
@@ -101,7 +104,7 @@ def get_up_block(
  add_upsample,
  resnet_eps,
  resnet_act_fn,
- attn_num_head_channels,
+ num_attention_heads,
  resnet_groups=None,
  cross_attention_dim=None,
  dual_cross_attention=False,
@@ -141,7 +144,7 @@ def get_up_block(
  resnet_act_fn=resnet_act_fn,
  resnet_groups=resnet_groups,
  cross_attention_dim=cross_attention_dim,
- attn_num_head_channels=attn_num_head_channels,
+ num_attention_heads=num_attention_heads,
  dual_cross_attention=dual_cross_attention,
  use_linear_projection=use_linear_projection,
  only_cross_attention=only_cross_attention,
@@ -153,17 +156,17 @@ def get_up_block(
  # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
  class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  r"""
- UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
- timestep and returns sample shaped output.
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.

- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
- implements for all the models (such as downloading or saving, etc.)
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).

  Parameters:
  sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
  Height and width of input/output sample.
- in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
- out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
  center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
  flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
  Whether to flip the sin to cos in the time embedding.
@@ -171,9 +174,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
  The tuple of downsample blocks to use.
  mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
- The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip
- the mid block layer if `None`.
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
+ Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
+ `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
  The tuple of upsample blocks to use.
  only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
  Whether to include self-attention in the basic transformer blocks, see
@@ -185,50 +188,58 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
  norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
- If `None`, it will skip the normalization and activation layers in post-processing
+ If `None`, normalization and activation layers is skipped in post-processing.
  norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
  cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
  The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
+ [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
  encoder_hid_dim (`int`, *optional*, defaults to None):
  If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
  dimension to `cross_attention_dim`.
- encoder_hid_dim_type (`str`, *optional*, defaults to None):
- If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
  embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
  attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
  resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
- for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
- class_embed_type (`str`, *optional*, defaults to None):
+ for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
  The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
  `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
- addition_embed_type (`str`, *optional*, defaults to None):
+ addition_embed_type (`str`, *optional*, defaults to `None`):
  Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
  "text". "text" will use the `TextTimeEmbedding` layer.
- num_class_embeds (`int`, *optional*, defaults to None):
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
  Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
  class conditioning with `class_embed_type` equal to `None`.
- time_embedding_type (`str`, *optional*, default to `positional`):
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
  The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
- time_embedding_dim (`int`, *optional*, default to `None`):
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
  An optional override for the dimension of the projected time embedding.
- time_embedding_act_fn (`str`, *optional*, default to `None`):
- Optional activation function to use on the time embeddings only one time before they as passed to the rest
- of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
- timestep_post_act (`str, *optional*, default to `None`):
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
  The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
- time_cond_proj_dim (`int`, *optional*, default to `None`):
- The dimension of `cond_proj` layer in timestep embedding.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
  conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
  conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
  projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
- using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
  class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
  embeddings with the class embeddings.
  mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
  Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
- `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
- default to `False`.
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
  """

  _supports_gradient_checkpointing = True
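The constructor arguments documented above mirror the public `UNet2DConditionModel` (item 24 in the file list). A minimal sketch of how the new options might be passed, with illustrative values only, not taken from any shipped checkpoint:

    from diffusers import UNet2DConditionModel

    # Hedged example: a tiny UNet exercising the newly documented config arguments.
    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        attention_head_dim=8,            # still the knob that sets the head count in 0.18.x
        transformer_layers_per_block=1,  # new in 0.18: an int or a per-block tuple
    )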
@@ -264,13 +275,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  norm_num_groups: Optional[int] = 32,
  norm_eps: float = 1e-5,
  cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
  encoder_hid_dim: Optional[int] = None,
  encoder_hid_dim_type: Optional[str] = None,
  attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
  dual_cross_attention: bool = False,
  use_linear_projection: bool = False,
  class_embed_type: Optional[str] = None,
  addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
  num_class_embeds: Optional[int] = None,
  upcast_attention: bool = False,
  resnet_time_scale_shift: str = "default",
@@ -293,6 +307,22 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  self.sample_size = sample_size

+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
+ " because of a naming issue as described in"
+ " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
+ " `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
  # Check inputs
  if len(down_block_types) != len(up_block_types):
  raise ValueError(
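The comment block above boils down to a single fallback assignment; a small sketch of the resulting behaviour with hypothetical config values:

    # `attention_head_dim` has historically stored the number of heads, despite its name.
    attention_head_dim = (5, 10, 20, 20)  # per-block head counts from an existing config
    num_attention_heads = None            # cannot be passed explicitly until diffusers v0.19
    num_attention_heads = num_attention_heads or attention_head_dim
    assert num_attention_heads == (5, 10, 20, 20)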
@@ -312,6 +342,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
  )

+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
+ f" {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
  if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
  raise ValueError(
  "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
@@ -384,7 +420,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  image_embed_dim=cross_attention_dim,
  cross_attention_dim=cross_attention_dim,
  )
-
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
  elif encoder_hid_dim_type is not None:
  raise ValueError(
  f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
@@ -437,6 +478,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  self.add_embedding = TextImageTimeEmbedding(
  text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
  )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
  elif addition_embed_type is not None:
  raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

@@ -457,6 +507,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  if mid_block_only_cross_attention is None:
  mid_block_only_cross_attention = False

+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
  if isinstance(attention_head_dim, int):
  attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -466,6 +519,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  if isinstance(layers_per_block, int):
  layers_per_block = [layers_per_block] * len(down_block_types)

+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
  if class_embeddings_concat:
  # The time embeddings are concatenated with the class embeddings. The dimension of the
  # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
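As with the existing per-block settings, a scalar `transformer_layers_per_block` (or `num_attention_heads`) is broadcast to one entry per down block; a short sketch with hypothetical values:

    down_block_types = ("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")
    transformer_layers_per_block = 1
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
    assert transformer_layers_per_block == [1, 1, 1]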
@@ -484,6 +540,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  down_block = get_down_block(
  down_block_type,
  num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
  in_channels=input_channel,
  out_channels=output_channel,
  temb_channels=blocks_time_embed_dim,
@@ -492,7 +549,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  resnet_act_fn=act_fn,
  resnet_groups=norm_num_groups,
  cross_attention_dim=cross_attention_dim[i],
- attn_num_head_channels=attention_head_dim[i],
+ num_attention_heads=num_attention_heads[i],
  downsample_padding=downsample_padding,
  dual_cross_attention=dual_cross_attention,
  use_linear_projection=use_linear_projection,
@@ -502,12 +559,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  resnet_skip_time_act=resnet_skip_time_act,
  resnet_out_scale_factor=resnet_out_scale_factor,
  cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
  )
  self.down_blocks.append(down_block)

  # mid
  if mid_block_type == "UNetMidBlockFlatCrossAttn":
  self.mid_block = UNetMidBlockFlatCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
  in_channels=block_out_channels[-1],
  temb_channels=blocks_time_embed_dim,
  resnet_eps=norm_eps,
@@ -515,7 +574,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  output_scale_factor=mid_block_scale_factor,
  resnet_time_scale_shift=resnet_time_scale_shift,
  cross_attention_dim=cross_attention_dim[-1],
- attn_num_head_channels=attention_head_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
  resnet_groups=norm_num_groups,
  dual_cross_attention=dual_cross_attention,
  use_linear_projection=use_linear_projection,
@@ -529,7 +588,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  resnet_act_fn=act_fn,
  output_scale_factor=mid_block_scale_factor,
  cross_attention_dim=cross_attention_dim[-1],
- attn_num_head_channels=attention_head_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
  resnet_groups=norm_num_groups,
  resnet_time_scale_shift=resnet_time_scale_shift,
  skip_time_act=resnet_skip_time_act,
@@ -546,9 +605,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  # up
  reversed_block_out_channels = list(reversed(block_out_channels))
- reversed_attention_head_dim = list(reversed(attention_head_dim))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
  reversed_layers_per_block = list(reversed(layers_per_block))
  reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
  only_cross_attention = list(reversed(only_cross_attention))

  output_channel = reversed_block_out_channels[0]
@@ -569,6 +629,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  up_block = get_up_block(
  up_block_type,
  num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
  in_channels=input_channel,
  out_channels=output_channel,
  prev_output_channel=prev_output_channel,
@@ -578,7 +639,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  resnet_act_fn=act_fn,
  resnet_groups=norm_num_groups,
  cross_attention_dim=reversed_cross_attention_dim[i],
- attn_num_head_channels=reversed_attention_head_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
  dual_cross_attention=dual_cross_attention,
  use_linear_projection=use_linear_projection,
  only_cross_attention=only_cross_attention[i],
@@ -587,6 +648,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  resnet_skip_time_act=resnet_skip_time_act,
  resnet_out_scale_factor=resnet_out_scale_factor,
  cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
  )
  self.up_blocks.append(up_block)
  prev_output_channel = output_channel
@@ -634,11 +696,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
  r"""
+ Sets the attention processor to use to compute attention.
+
  Parameters:
- `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
  The instantiated processor class or a dictionary of processor classes that will be set as the processor
- of **all** `Attention` layers.
- In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.

  """
  count = len(self.attn_processors.keys())
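A hedged usage sketch of the behaviour documented above, using the stock `AttnProcessor` on a small illustrative UNet (the config values are arbitrary):

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor

    unet = UNet2DConditionModel(
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )
    unet.set_attn_processor(AttnProcessor())  # one processor instance for every Attention layer
    # or keyed per layer, as recommended for trainable processors:
    unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})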
@@ -672,13 +738,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  r"""
  Enable sliced attention computation.

- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.

  Args:
  slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
  provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
  must be a multiple of `slice_size`.
  """
@@ -753,29 +819,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  return_dict: bool = True,
  ) -> Union[UNet2DConditionOutput, Tuple]:
  r"""
+ The [`UNetFlatConditionModel`] forward method.
+
  Args:
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
- timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
- encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
  encoder_attention_mask (`torch.Tensor`):
- (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
- discard. Mask will be converted into a bias, which adds large negative values to attention scores
- corresponding to "discard" tokens.
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
  cross_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
- added_cond_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
- embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
- `addition_embed_type` for more information.
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.

  Returns:
  [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
- returning a tuple, the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
  """
  # By default samples have to be AT least a multiple of the overall upsampling factor.
  # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
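A hedged call sketch matching the documented signature and return handling (reusing the small illustrative `unet` from the earlier sketch; shapes are arbitrary):

    import torch

    sample = torch.randn(1, 4, 32, 32)              # (batch, channel, height, width)
    timestep = torch.tensor([10])
    encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, feature_dim)

    out = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]
    # with return_dict=True (the default) the same tensor is available as `.sample`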
@@ -841,6 +909,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  t_emb = t_emb.to(dtype=sample.dtype)

  emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None

  if self.class_embedding is not None:
  if class_labels is None:
@@ -862,9 +931,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  if self.config.addition_embed_type == "text":
  aug_emb = self.add_embedding(encoder_hidden_states)
- emb = emb + aug_emb
  elif self.config.addition_embed_type == "text_image":
- # Kadinsky 2.1 - style
+ # Kandinsky 2.1 - style
  if "image_embeds" not in added_cond_kwargs:
  raise ValueError(
  f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
@@ -873,9 +941,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  )
  image_embs = added_cond_kwargs.get("image_embeds")
  text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
-
  aug_emb = self.add_embedding(text_embs, image_embs)
- emb = emb + aug_emb
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
+ " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
+ " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb

  if self.time_embed_act is not None:
  emb = self.time_embed_act(emb)
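The new 'text_time' branch above is the conditioning path introduced for the Stable Diffusion XL UNet; a hedged sketch of the extra inputs it expects (dimensions are illustrative, not taken from a real checkpoint):

    added_cond_kwargs = {
        "text_embeds": torch.randn(1, 1280),  # pooled text embedding
        "time_ids": torch.randn(1, 6),        # e.g. original size, crop coords, target size
    }
    # Inside forward, `time_ids` are flattened and projected by `add_time_proj`, reshaped
    # back per sample, concatenated with `text_embeds`, and passed through `add_embedding`
    # to produce `aug_emb`, which is then added onto the time embedding `emb`.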
@@ -892,7 +999,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  image_embeds = added_cond_kwargs.get("image_embeds")
  encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
-
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
+ " the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
  # 2. pre-process
  sample = self.conv_in(sample)

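For the Kandinsky 2.2 style 'image_proj' path added above, the cross-attention conditioning comes from an image embedding instead of text; a hedged fragment (embedding size is illustrative, `torch` as imported in the earlier sketch):

    added_cond_kwargs = {"image_embeds": torch.randn(1, 1280)}
    # When encoder_hid_dim_type == "image_proj", the incoming `encoder_hidden_states`
    # is replaced by self.encoder_hid_proj(image_embeds) before the UNet blocks run.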
@@ -1187,12 +1302,13 @@ class CrossAttnDownBlockFlat(nn.Module):
  temb_channels: int,
  dropout: float = 0.0,
  num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
  resnet_eps: float = 1e-6,
  resnet_time_scale_shift: str = "default",
  resnet_act_fn: str = "swish",
  resnet_groups: int = 32,
  resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
+ num_attention_heads=1,
  cross_attention_dim=1280,
  output_scale_factor=1.0,
  downsample_padding=1,
@@ -1207,7 +1323,7 @@ class CrossAttnDownBlockFlat(nn.Module):
  attentions = []

  self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
+ self.num_attention_heads = num_attention_heads

  for i in range(num_layers):
  in_channels = in_channels if i == 0 else out_channels
@@ -1228,10 +1344,10 @@ class CrossAttnDownBlockFlat(nn.Module):
  if not dual_cross_attention:
  attentions.append(
  Transformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
+ num_attention_heads,
+ out_channels // num_attention_heads,
  in_channels=out_channels,
- num_layers=1,
+ num_layers=transformer_layers_per_block,
  cross_attention_dim=cross_attention_dim,
  norm_num_groups=resnet_groups,
  use_linear_projection=use_linear_projection,
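In the renamed call above, the first two positional arguments of `Transformer2DModel` are the head count and the per-head width, so their product recovers the block width; a worked example with hypothetical numbers:

    out_channels = 640
    num_attention_heads = 10
    head_width = out_channels // num_attention_heads  # -> 64; 10 heads * 64 = 640 channels
    # num_layers=transformer_layers_per_block now stacks that many transformer blocks (new in 0.18)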
@@ -1242,8 +1358,8 @@ class CrossAttnDownBlockFlat(nn.Module):
  else:
  attentions.append(
  DualTransformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
+ num_attention_heads,
+ out_channels // num_attention_heads,
  in_channels=out_channels,
  num_layers=1,
  cross_attention_dim=cross_attention_dim,
@@ -1421,12 +1537,13 @@ class CrossAttnUpBlockFlat(nn.Module):
  temb_channels: int,
  dropout: float = 0.0,
  num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
  resnet_eps: float = 1e-6,
  resnet_time_scale_shift: str = "default",
  resnet_act_fn: str = "swish",
  resnet_groups: int = 32,
  resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
+ num_attention_heads=1,
  cross_attention_dim=1280,
  output_scale_factor=1.0,
  add_upsample=True,
@@ -1440,7 +1557,7 @@ class CrossAttnUpBlockFlat(nn.Module):
  attentions = []

  self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
+ self.num_attention_heads = num_attention_heads

  for i in range(num_layers):
  res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -1463,10 +1580,10 @@ class CrossAttnUpBlockFlat(nn.Module):
  if not dual_cross_attention:
  attentions.append(
  Transformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
+ num_attention_heads,
+ out_channels // num_attention_heads,
  in_channels=out_channels,
- num_layers=1,
+ num_layers=transformer_layers_per_block,
  cross_attention_dim=cross_attention_dim,
  norm_num_groups=resnet_groups,
  use_linear_projection=use_linear_projection,
@@ -1477,8 +1594,8 @@ class CrossAttnUpBlockFlat(nn.Module):
  else:
  attentions.append(
  DualTransformer2DModel(
- attn_num_head_channels,
- out_channels // attn_num_head_channels,
+ num_attention_heads,
+ out_channels // num_attention_heads,
  in_channels=out_channels,
  num_layers=1,
  cross_attention_dim=cross_attention_dim,
@@ -1567,12 +1684,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  temb_channels: int,
  dropout: float = 0.0,
  num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
  resnet_eps: float = 1e-6,
  resnet_time_scale_shift: str = "default",
  resnet_act_fn: str = "swish",
  resnet_groups: int = 32,
  resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
+ num_attention_heads=1,
  output_scale_factor=1.0,
  cross_attention_dim=1280,
  dual_cross_attention=False,
@@ -1582,7 +1700,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  super().__init__()

  self.has_cross_attention = True
- self.attn_num_head_channels = attn_num_head_channels
+ self.num_attention_heads = num_attention_heads
  resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

  # there is always at least one resnet
@@ -1606,10 +1724,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  if not dual_cross_attention:
  attentions.append(
  Transformer2DModel(
- attn_num_head_channels,
- in_channels // attn_num_head_channels,
+ num_attention_heads,
+ in_channels // num_attention_heads,
  in_channels=in_channels,
- num_layers=1,
+ num_layers=transformer_layers_per_block,
  cross_attention_dim=cross_attention_dim,
  norm_num_groups=resnet_groups,
  use_linear_projection=use_linear_projection,
@@ -1619,8 +1737,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  else:
  attentions.append(
  DualTransformer2DModel(
- attn_num_head_channels,
- in_channels // attn_num_head_channels,
+ num_attention_heads,
+ in_channels // num_attention_heads,
  in_channels=in_channels,
  num_layers=1,
  cross_attention_dim=cross_attention_dim,
@@ -1682,7 +1800,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
  resnet_act_fn: str = "swish",
  resnet_groups: int = 32,
  resnet_pre_norm: bool = True,
- attn_num_head_channels=1,
+ attention_head_dim=1,
  output_scale_factor=1.0,
  cross_attention_dim=1280,
  skip_time_act=False,
@@ -1693,10 +1811,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):

  self.has_cross_attention = True

- self.attn_num_head_channels = attn_num_head_channels
+ self.attention_head_dim = attention_head_dim
  resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

- self.num_heads = in_channels // self.attn_num_head_channels
+ self.num_heads = in_channels // self.attention_head_dim

  # there is always at least one resnet
  resnets = [
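With the rename above, `attention_head_dim` in `UNetMidBlockFlatSimpleCrossAttn` really is the per-head width; a worked example with hypothetical numbers:

    in_channels = 1280
    attention_head_dim = 64
    num_heads = in_channels // attention_head_dim  # -> 20 heads, each 64 channels wide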
@@ -1726,7 +1844,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
  query_dim=in_channels,
  cross_attention_dim=in_channels,
  heads=self.num_heads,
- dim_head=attn_num_head_channels,
+ dim_head=self.attention_head_dim,
  added_kv_proj_dim=cross_attention_dim,
  norm_num_groups=resnet_groups,
  bias=True,