diffusers 0.17.1-py3-none-any.whl → 0.18.2-py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/models/unet_2d_blocks.py
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import is_torch_version
+from ..utils import is_torch_version, logging
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -26,6 +26,9 @@ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D,
 from .transformer_2d import Transformer2DModel
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 def get_down_block(
     down_block_type,
     num_layers,
@@ -35,7 +38,8 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    transformer_layers_per_block=1,
+    num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -47,7 +51,16 @@ def get_down_block(
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
+    attention_head_dim=None,
+    downsample_type=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warn(
+            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
         return DownBlock2D(
@@ -77,24 +90,29 @@ def get_down_block(
             output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
+        if add_downsample is False:
+            downsample_type = None
+        else:
+            downsample_type = downsample_type or "conv"  # default to 'conv'
         return AttnDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
-            add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            downsample_type=downsample_type,
         )
     elif down_block_type == "CrossAttnDownBlock2D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
         return CrossAttnDownBlock2D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -104,7 +122,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -124,7 +142,7 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
@@ -152,8 +170,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
-            downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif down_block_type == "DownEncoderBlock2D":
@@ -178,7 +195,7 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif down_block_type == "KDownBlock2D":
@@ -201,7 +218,7 @@ def get_down_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             add_self_attention=True if not add_downsample else False,
         )
     raise ValueError(f"{down_block_type} does not exist.")
@@ -217,7 +234,8 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    transformer_layers_per_block=1,
+    num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -228,7 +246,16 @@ def get_up_block(
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
+    attention_head_dim=None,
+    upsample_type=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warn(
+            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
         return UpBlock2D(
@@ -263,6 +290,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
         return CrossAttnUpBlock2D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -272,7 +300,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -293,7 +321,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
@@ -301,18 +329,23 @@ def get_up_block(
             cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
+        if add_upsample is False:
+            upsample_type = None
+        else:
+            upsample_type = upsample_type or "conv"  # default to 'conv'
+
         return AttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
-            add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            upsample_type=upsample_type,
         )
     elif up_block_type == "SkipUpBlock2D":
         return SkipUpBlock2D(
@@ -336,7 +369,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif up_block_type == "UpDecoderBlock2D":
@@ -360,7 +393,7 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             temb_channels=temb_channels,
         )
@@ -384,7 +417,7 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
         )
 
     raise ValueError(f"{up_block_type} does not exist.")
@@ -403,7 +436,7 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
     ):
         super().__init__()
@@ -427,13 +460,19 @@ class UNetMidBlock2D(nn.Module):
         ]
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+            )
+            attention_head_dim = in_channels
+
         for _ in range(num_layers):
             if self.add_attention:
                 attentions.append(
                     Attention(
                         in_channels,
-                        heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                        dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
+                        heads=in_channels // attention_head_dim,
+                        dim_head=attention_head_dim,
                         rescale_output_factor=output_scale_factor,
                         eps=resnet_eps,
                         norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
@@ -482,12 +521,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
        output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
@@ -497,7 +537,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         super().__init__()
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -521,10 +561,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -534,8 +574,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -596,7 +636,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
@@ -607,10 +647,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
 
         self.has_cross_attention = True
 
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
-        self.num_heads = in_channels // self.attn_num_head_channels
+        self.num_heads = in_channels // self.attention_head_dim
 
         # there is always at least one resnet
         resnets = [
@@ -640,7 +680,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     query_dim=in_channels,
                     cross_attention_dim=in_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=self.attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -720,14 +760,21 @@ class AttnDownBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         downsample_padding=1,
-        add_downsample=True,
+        downsample_type="conv",
     ):
         super().__init__()
         resnets = []
         attentions = []
+        self.downsample_type = downsample_type
+
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -748,8 +795,8 @@ class AttnDownBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -763,7 +810,7 @@ class AttnDownBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-        if add_downsample:
+        if downsample_type == "conv":
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
@@ -771,6 +818,24 @@ class AttnDownBlock2D(nn.Module):
                     )
                 ]
             )
+        elif downsample_type == "resnet":
+            self.downsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        down=True,
+                    )
+                ]
+            )
         else:
             self.downsamplers = None
 
@@ -780,11 +845,14 @@ class AttnDownBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
             hidden_states = attn(hidden_states)
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                if self.downsample_type == "resnet":
+                    hidden_states = downsampler(hidden_states, temb=temb)
+                else:
+                    hidden_states = downsampler(hidden_states)
 
             output_states += (hidden_states,)
 
@@ -799,12 +867,13 @@ class CrossAttnDownBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         downsample_padding=1,
@@ -819,7 +888,7 @@ class CrossAttnDownBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -840,10 +909,10 @@ class CrossAttnDownBlock2D(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -854,8 +923,8 @@ class CrossAttnDownBlock2D(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -1099,7 +1168,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         add_downsample=True,
         downsample_padding=1,
@@ -1108,6 +1177,12 @@ class AttnDownEncoderBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1127,8 +1202,8 @@ class AttnDownEncoderBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -1177,15 +1252,20 @@ class AttnSkipDownBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=np.sqrt(2.0),
-        downsample_padding=1,
         add_downsample=True,
     ):
         super().__init__()
         self.attentions = nn.ModuleList([])
         self.resnets = nn.ModuleList([])
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
@@ -1206,8 +1286,8 @@ class AttnSkipDownBlock2D(nn.Module):
         self.attentions.append(
             Attention(
                 out_channels,
-                heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
                 rescale_output_factor=output_scale_factor,
                 eps=resnet_eps,
                 norm_num_groups=32,
@@ -1451,7 +1531,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
@@ -1466,8 +1546,8 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []
 
-        self.attn_num_head_channels = attn_num_head_channels
-        self.num_heads = out_channels // self.attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
+        self.num_heads = out_channels // self.attention_head_dim
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -1496,7 +1576,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     query_dim=out_channels,
                     cross_attention_dim=out_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -1686,7 +1766,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         num_layers: int = 4,
         resnet_group_size: int = 32,
         add_downsample=True,
-        attn_num_head_channels: int = 64,
+        attention_head_dim: int = 64,
         add_self_attention: bool = False,
         resnet_eps: float = 1e-5,
         resnet_act_fn: str = "gelu",
@@ -1719,8 +1799,8 @@ class KCrossAttnDownBlock2D(nn.Module):
             attentions.append(
                 KAttentionBlock(
                     out_channels,
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     cross_attention_dim=cross_attention_dim,
                     temb_channels=temb_channels,
                     attention_bias=True,
@@ -1817,14 +1897,22 @@ class AttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
-        add_upsample=True,
+        upsample_type="conv",
    ):
         super().__init__()
         resnets = []
         attentions = []
 
+        self.upsample_type = upsample_type
+
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1846,8 +1934,8 @@ class AttnUpBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -1861,8 +1949,26 @@ class AttnUpBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-        if add_upsample:
+        if upsample_type == "conv":
             self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        elif upsample_type == "resnet":
+            self.upsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        up=True,
+                    )
+                ]
+            )
         else:
             self.upsamplers = None
 
@@ -1878,7 +1984,10 @@ class AttnUpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                if self.upsample_type == "resnet":
+                    hidden_states = upsampler(hidden_states, temb=temb)
+                else:
+                    hidden_states = upsampler(hidden_states)
 
         return hidden_states
 
@@ -1892,12 +2001,13 @@ class CrossAttnUpBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -1911,7 +2021,7 @@ class CrossAttnUpBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -1934,10 +2044,10 @@ class CrossAttnUpBlock2D(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1948,8 +2058,8 @@ class CrossAttnUpBlock2D(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -2178,7 +2288,7 @@ class AttnUpDecoderBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         add_upsample=True,
         temb_channels=None,
@@ -2187,6 +2297,12 @@ class AttnUpDecoderBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             input_channels = in_channels if i == 0 else out_channels
 
@@ -2207,8 +2323,8 @@ class AttnUpDecoderBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                    dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
@@ -2253,9 +2369,8 @@ class AttnSkipUpBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=np.sqrt(2.0),
-        upsample_padding=1,
         add_upsample=True,
     ):
         super().__init__()
@@ -2282,11 +2397,17 @@ class AttnSkipUpBlock2D(nn.Module):
                 )
             )
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         self.attentions.append(
             Attention(
                 out_channels,
-                heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
-                dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
                 rescale_output_factor=output_scale_factor,
                 eps=resnet_eps,
                 norm_num_groups=32,
@@ -2563,7 +2684,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -2576,9 +2697,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
 
-        self.num_heads = out_channels // self.attn_num_head_channels
+        self.num_heads = out_channels // self.attention_head_dim
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -2609,7 +2730,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     query_dim=out_channels,
                     cross_attention_dim=out_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=self.attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -2804,7 +2925,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         resnet_eps: float = 1e-5,
         resnet_act_fn: str = "gelu",
         resnet_group_size: int = 32,
-        attn_num_head_channels=1,  # attention dim_head
+        attention_head_dim=1,  # attention dim_head
         cross_attention_dim: int = 768,
         add_upsample: bool = True,
         upcast_attention: bool = False,
@@ -2818,7 +2939,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         add_self_attention = True if is_first_block else False
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
 
         # in_channels, and out_channels for the block (k-unet)
         k_in_channels = out_channels if is_first_block else 2 * out_channels
@@ -2854,10 +2975,10 @@ class KCrossAttnUpBlock2D(nn.Module):
             attentions.append(
                 KAttentionBlock(
                     k_out_channels if (i == num_layers - 1) else out_channels,
-                    k_out_channels // attn_num_head_channels
+                    k_out_channels // attention_head_dim
                     if (i == num_layers - 1)
-                    else out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    else out_channels // attention_head_dim,
+                    attention_head_dim,
                     cross_attention_dim=cross_attention_dim,
                     temb_channels=temb_channels,
                     attention_bias=True,
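For readers tracking the API change in this diff: blocks with plain self-attention now take `attention_head_dim` (formerly `attn_num_head_channels`) and, for `AttnDownBlock2D`/`AttnUpBlock2D`, a `downsample_type`/`upsample_type` string instead of the old `add_downsample`/`add_upsample` flag, while cross-attention blocks take `num_attention_heads` and a new `transformer_layers_per_block`. The snippet below is a hedged illustration of what that means for callers, assuming a diffusers 0.18.x install; the channel, head, and layer sizes are arbitrary example values, not values taken from this diff.

# Illustrative sketch only (assumes diffusers>=0.18); sizes below are made-up examples.
import torch

from diffusers.models.unet_2d_blocks import AttnDownBlock2D, get_down_block

# Plain self-attention block: `attention_head_dim` replaces `attn_num_head_channels`,
# and downsample_type="resnet" selects the new ResnetBlock2D(down=True) downsampler
# instead of the strided conv used by downsample_type="conv".
block = AttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    attention_head_dim=64,  # example value; heads = out_channels // attention_head_dim
    downsample_type="resnet",
)
sample = torch.randn(1, 320, 32, 32)
temb = torch.randn(1, 1280)
hidden_states, output_states = block(sample, temb)

# Cross-attention block built through the helper: pass `num_attention_heads`;
# omitting `attention_head_dim` triggers the logger.warn fallback shown above.
cross_block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    cross_attention_dim=768,
    num_attention_heads=8,  # replaces attn_num_head_channels for cross-attention blocks
    transformer_layers_per_block=1,  # new in 0.18 (used by SD-XL-style UNets)
)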