diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
--- a/diffusers/models/unet_2d_blocks.py
+++ b/diffusers/models/unet_2d_blocks.py
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import is_torch_version
+from ..utils import is_torch_version, logging
 from .attention import AdaGroupNorm
 from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
@@ -26,6 +26,9 @@ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D,
 from .transformer_2d import Transformer2DModel
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 def get_down_block(
     down_block_type,
     num_layers,
@@ -35,7 +38,8 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    transformer_layers_per_block=1,
+    num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -47,7 +51,16 @@ def get_down_block(
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
+    attention_head_dim=None,
+    downsample_type=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warn(
+            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
         return DownBlock2D(
@@ -77,24 +90,29 @@ def get_down_block(
             output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
+        if add_downsample is False:
+            downsample_type = None
+        else:
+            downsample_type = downsample_type or "conv"  # default to 'conv'
         return AttnDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
-            add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            downsample_type=downsample_type,
         )
     elif down_block_type == "CrossAttnDownBlock2D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
         return CrossAttnDownBlock2D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -104,7 +122,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -124,7 +142,7 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
@@ -152,8 +170,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
-            downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif down_block_type == "DownEncoderBlock2D":
@@ -178,7 +195,7 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif down_block_type == "KDownBlock2D":
@@ -201,7 +218,7 @@ def get_down_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             add_self_attention=True if not add_downsample else False,
         )
     raise ValueError(f"{down_block_type} does not exist.")
@@ -217,7 +234,8 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    transformer_layers_per_block=1,
+    num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -228,7 +246,16 @@ def get_up_block(
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
     cross_attention_norm=None,
+    attention_head_dim=None,
+    upsample_type=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warn(
+            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
         return UpBlock2D(
@@ -263,6 +290,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
         return CrossAttnUpBlock2D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -272,7 +300,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -293,7 +321,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
@@ -301,18 +329,23 @@ def get_up_block(
             cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
+        if add_upsample is False:
+            upsample_type = None
+        else:
+            upsample_type = upsample_type or "conv"  # default to 'conv'
+
         return AttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
-            add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            upsample_type=upsample_type,
         )
     elif up_block_type == "SkipUpBlock2D":
         return SkipUpBlock2D(
@@ -336,7 +369,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
         )
     elif up_block_type == "UpDecoderBlock2D":
@@ -360,7 +393,7 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
             resnet_time_scale_shift=resnet_time_scale_shift,
             temb_channels=temb_channels,
         )
@@ -384,7 +417,7 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            attention_head_dim=attention_head_dim,
         )
 
     raise ValueError(f"{up_block_type} does not exist.")
@@ -403,7 +436,7 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
     ):
         super().__init__()
@@ -427,13 +460,19 @@ class UNetMidBlock2D(nn.Module):
         ]
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+            )
+            attention_head_dim = in_channels
+
         for _ in range(num_layers):
             if self.add_attention:
                 attentions.append(
                     Attention(
                         in_channels,
-                        heads=in_channels // attn_num_head_channels,
-                        dim_head=attn_num_head_channels,
+                        heads=in_channels // attention_head_dim,
+                        dim_head=attention_head_dim,
                         rescale_output_factor=output_scale_factor,
                         eps=resnet_eps,
                         norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
@@ -482,12 +521,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
@@ -497,7 +537,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         super().__init__()
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -521,10 +561,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -534,8 +574,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -596,7 +636,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
@@ -607,10 +647,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
 
         self.has_cross_attention = True
 
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
-        self.num_heads = in_channels // self.attn_num_head_channels
+        self.num_heads = in_channels // self.attention_head_dim
 
         # there is always at least one resnet
         resnets = [
@@ -640,7 +680,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     query_dim=in_channels,
                     cross_attention_dim=in_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=self.attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -720,14 +760,21 @@ class AttnDownBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
        output_scale_factor=1.0,
         downsample_padding=1,
-        add_downsample=True,
+        downsample_type="conv",
     ):
         super().__init__()
         resnets = []
         attentions = []
+        self.downsample_type = downsample_type
+
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -748,8 +795,8 @@ class AttnDownBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels,
-                    dim_head=attn_num_head_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -763,7 +810,7 @@ class AttnDownBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-        if add_downsample:
+        if downsample_type == "conv":
             self.downsamplers = nn.ModuleList(
                 [
                     Downsample2D(
@@ -771,6 +818,24 @@ class AttnDownBlock2D(nn.Module):
                     )
                 ]
             )
+        elif downsample_type == "resnet":
+            self.downsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        down=True,
+                    )
+                ]
+            )
         else:
             self.downsamplers = None
 
@@ -780,11 +845,14 @@ class AttnDownBlock2D(nn.Module):
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
             hidden_states = attn(hidden_states)
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                if self.downsample_type == "resnet":
+                    hidden_states = downsampler(hidden_states, temb=temb)
+                else:
+                    hidden_states = downsampler(hidden_states)
 
             output_states += (hidden_states,)
 
@@ -799,12 +867,13 @@ class CrossAttnDownBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         downsample_padding=1,
@@ -819,7 +888,7 @@ class CrossAttnDownBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -840,10 +909,10 @@ class CrossAttnDownBlock2D(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -854,8 +923,8 @@ class CrossAttnDownBlock2D(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -1099,7 +1168,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         add_downsample=True,
         downsample_padding=1,
@@ -1108,6 +1177,12 @@ class AttnDownEncoderBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1127,8 +1202,8 @@ class AttnDownEncoderBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels,
-                    dim_head=attn_num_head_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -1177,15 +1252,20 @@ class AttnSkipDownBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=np.sqrt(2.0),
-        downsample_padding=1,
         add_downsample=True,
     ):
         super().__init__()
         self.attentions = nn.ModuleList([])
         self.resnets = nn.ModuleList([])
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
@@ -1206,8 +1286,8 @@ class AttnSkipDownBlock2D(nn.Module):
         self.attentions.append(
             Attention(
                 out_channels,
-                heads=out_channels // attn_num_head_channels,
-                dim_head=attn_num_head_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
                 rescale_output_factor=output_scale_factor,
                 eps=resnet_eps,
                 norm_num_groups=32,
@@ -1451,7 +1531,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
@@ -1466,8 +1546,8 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         resnets = []
         attentions = []
 
-        self.attn_num_head_channels = attn_num_head_channels
-        self.num_heads = out_channels // self.attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
+        self.num_heads = out_channels // self.attention_head_dim
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -1496,7 +1576,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     query_dim=out_channels,
                     cross_attention_dim=out_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -1686,7 +1766,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         num_layers: int = 4,
         resnet_group_size: int = 32,
         add_downsample=True,
-        attn_num_head_channels: int = 64,
+        attention_head_dim: int = 64,
         add_self_attention: bool = False,
         resnet_eps: float = 1e-5,
         resnet_act_fn: str = "gelu",
@@ -1719,8 +1799,8 @@ class KCrossAttnDownBlock2D(nn.Module):
             attentions.append(
                 KAttentionBlock(
                     out_channels,
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     cross_attention_dim=cross_attention_dim,
                     temb_channels=temb_channels,
                     attention_bias=True,
@@ -1817,14 +1897,22 @@ class AttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
-        add_upsample=True,
+        upsample_type="conv",
     ):
         super().__init__()
         resnets = []
         attentions = []
 
+        self.upsample_type = upsample_type
+
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1846,8 +1934,8 @@ class AttnUpBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels,
-                    dim_head=attn_num_head_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups,
@@ -1861,8 +1949,26 @@ class AttnUpBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-        if add_upsample:
+        if upsample_type == "conv":
             self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        elif upsample_type == "resnet":
+            self.upsamplers = nn.ModuleList(
+                [
+                    ResnetBlock2D(
+                        in_channels=out_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                        up=True,
+                    )
+                ]
+            )
         else:
             self.upsamplers = None
 
@@ -1878,7 +1984,10 @@ class AttnUpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                if self.upsample_type == "resnet":
+                    hidden_states = upsampler(hidden_states, temb=temb)
+                else:
+                    hidden_states = upsampler(hidden_states)
 
         return hidden_states
 
@@ -1892,12 +2001,13 @@ class CrossAttnUpBlock2D(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -1911,7 +2021,7 @@ class CrossAttnUpBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -1934,10 +2044,10 @@ class CrossAttnUpBlock2D(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1948,8 +2058,8 @@ class CrossAttnUpBlock2D(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -2178,7 +2288,7 @@ class AttnUpDecoderBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         add_upsample=True,
         temb_channels=None,
@@ -2187,6 +2297,12 @@ class AttnUpDecoderBlock2D(nn.Module):
         resnets = []
         attentions = []
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         for i in range(num_layers):
             input_channels = in_channels if i == 0 else out_channels
 
@@ -2207,8 +2323,8 @@ class AttnUpDecoderBlock2D(nn.Module):
             attentions.append(
                 Attention(
                     out_channels,
-                    heads=out_channels // attn_num_head_channels,
-                    dim_head=attn_num_head_channels,
+                    heads=out_channels // attention_head_dim,
+                    dim_head=attention_head_dim,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
                     norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
@@ -2253,9 +2369,8 @@ class AttnSkipUpBlock2D(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=np.sqrt(2.0),
-        upsample_padding=1,
         add_upsample=True,
     ):
         super().__init__()
@@ -2282,11 +2397,17 @@ class AttnSkipUpBlock2D(nn.Module):
                 )
             )
 
+        if attention_head_dim is None:
+            logger.warn(
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
         self.attentions.append(
             Attention(
                 out_channels,
-                heads=out_channels // attn_num_head_channels,
-                dim_head=attn_num_head_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
                 rescale_output_factor=output_scale_factor,
                 eps=resnet_eps,
                 norm_num_groups=32,
@@ -2563,7 +2684,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -2576,9 +2697,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
 
-        self.num_heads = out_channels // self.attn_num_head_channels
+        self.num_heads = out_channels // self.attention_head_dim
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -2609,7 +2730,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     query_dim=out_channels,
                     cross_attention_dim=out_channels,
                     heads=self.num_heads,
-                    dim_head=attn_num_head_channels,
+                    dim_head=self.attention_head_dim,
                     added_kv_proj_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
                     bias=True,
@@ -2804,7 +2925,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         resnet_eps: float = 1e-5,
         resnet_act_fn: str = "gelu",
         resnet_group_size: int = 32,
-        attn_num_head_channels=1,  # attention dim_head
+        attention_head_dim=1,  # attention dim_head
         cross_attention_dim: int = 768,
         add_upsample: bool = True,
         upcast_attention: bool = False,
@@ -2818,7 +2939,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         add_self_attention = True if is_first_block else False
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
 
         # in_channels, and out_channels for the block (k-unet)
         k_in_channels = out_channels if is_first_block else 2 * out_channels
@@ -2854,10 +2975,10 @@ class KCrossAttnUpBlock2D(nn.Module):
             attentions.append(
                 KAttentionBlock(
                     k_out_channels if (i == num_layers - 1) else out_channels,
-                    k_out_channels // attn_num_head_channels
+                    k_out_channels // attention_head_dim
                     if (i == num_layers - 1)
-                    else out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    else out_channels // attention_head_dim,
+                    attention_head_dim,
                     cross_attention_dim=cross_attention_dim,
                     temb_channels=temb_channels,
                     attention_bias=True,
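As a quick illustration of the renamed and added arguments in the diff above, here is a minimal sketch (not part of the wheel) that instantiates `AttnDownBlock2D` from diffusers 0.18 using the new `attention_head_dim` and `downsample_type` keywords; the channel, embedding, and spatial sizes are arbitrary assumptions chosen only so the block runs.

# Minimal sketch (assumed sizes) exercising the renamed/added arguments shown above:
# `attention_head_dim` replaces the old `attn_num_head_channels`, and
# `downsample_type="resnet"` selects the new ResnetBlock2D-based downsampler
# instead of the conv downsampler that `add_downsample=True` used to imply.
import torch

from diffusers.models.unet_2d_blocks import AttnDownBlock2D

block = AttnDownBlock2D(
    in_channels=32,
    out_channels=32,
    temb_channels=128,
    num_layers=1,
    attention_head_dim=8,      # heads = out_channels // attention_head_dim = 4
    downsample_type="resnet",  # new in 0.18; "conv" keeps the 0.17 behaviour
)

sample = torch.randn(1, 32, 16, 16)  # (batch, channels, height, width)
temb = torch.randn(1, 128)           # time embedding of shape (batch, temb_channels)
hidden_states, output_states = block(sample, temb)
print(hidden_states.shape)           # spatial size halved by the resnet downsampler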