diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
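The headline additions in this release are visible in the file list above: Stable Diffusion XL, Kandinsky 2.2, Shap-E, consistency-model, LDM3D and parallel-sampling ("paradigms") pipelines, plus a consistency-model scheduler and parallel DDIM/DDPM schedulers. As a quick orientation, a minimal sketch of loading one of the newly added pipelines follows; the checkpoint id, fp16 dtype and CUDA device are illustrative assumptions rather than part of the wheel diff, and StableDiffusionXLPipeline additionally needs the `invisible-watermark` package (hence the new dummy-objects module above).

import torch
from diffusers import StableDiffusionXLPipeline  # added in diffusers 0.18

# Assumed checkpoint id; substitute whichever SDXL weights you have access to.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")

The hunks reproduced below are from diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+209 -91). That module is kept in sync with UNet2DConditionModel (see the "# Copied from" marker), so it picks up the reworked num_attention_heads/attention_head_dim handling, the new transformer_layers_per_block option, and the additional embedding types used by SDXL ("text_time") and Kandinsky 2.2 ("image", "image_hint").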
@@ -18,6 +18,9 @@ from ...models.attention_processor import (
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import (
     GaussianFourierProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -41,7 +44,7 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -82,7 +85,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -101,7 +104,7 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -141,7 +144,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -153,17 +156,17 @@ def get_up_block(
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
 class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     r"""
-
-
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.

-    This model inherits from [`ModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).

     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
             Height and width of input/output sample.
-        in_channels (`int`, *optional*, defaults to 4):
-        out_channels (`int`, *optional*, defaults to 4):
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
@@ -171,9 +174,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
-
-            the mid block layer
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat"
+            Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
+            `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
             The tuple of upsample blocks to use.
         only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
             Whether to include self-attention in the basic transformer blocks, see
@@ -185,50 +188,58 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`,
+            If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
+            [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
-        encoder_hid_dim_type (`str`, *optional*, defaults to None):
-            If given, the `encoder_hidden_states` and potentially other embeddings
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
             embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
-            for
-        class_embed_type (`str`, *optional*, defaults to None):
+            for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
-        addition_embed_type (`str`, *optional*, defaults to None):
+        addition_embed_type (`str`, *optional*, defaults to `None`):
             Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
             "text". "text" will use the `TextTimeEmbedding` layer.
-
+        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
-        time_embedding_type (`str`, *optional*,
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
-        time_embedding_dim (`int`, *optional*,
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
             An optional override for the dimension of the projected time embedding.
-        time_embedding_act_fn (`str`, *optional*,
-            Optional activation function to use on the time embeddings
-
-        timestep_post_act (`str
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
-        time_cond_proj_dim (`int`, *optional*,
-            The dimension of `cond_proj` layer in timestep embedding.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
         conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
         conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
-
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
             embeddings with the class embeddings.
         mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
             Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
-            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None
-            `only_cross_attention` value
-
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+            otherwise.
     """

     _supports_gradient_checkpointing = True
@@ -264,13 +275,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -293,6 +307,22 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         self.sample_size = sample_size

+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
+                " because of a naming issue as described in"
+                " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
+                " `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -312,6 +342,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )

+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
+                f" {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
         if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
             raise ValueError(
                 "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
@@ -384,7 +420,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 image_embed_dim=cross_attention_dim,
                 cross_attention_dim=cross_attention_dim,
             )
-
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
                 f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
@@ -437,6 +478,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

@@ -457,6 +507,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if mid_block_only_cross_attention is None:
             mid_block_only_cross_attention = False

+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -466,6 +519,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if isinstance(layers_per_block, int):
             layers_per_block = [layers_per_block] * len(down_block_types)

+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -484,6 +540,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -492,7 +549,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim[i],
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -502,12 +559,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.down_blocks.append(down_block)

         # mid
         if mid_block_type == "UNetMidBlockFlatCrossAttn":
             self.mid_block = UNetMidBlockFlatCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
                 in_channels=block_out_channels[-1],
                 temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
@@ -515,7 +574,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -529,7 +588,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 cross_attention_dim=cross_attention_dim[-1],
-                attn_num_head_channels=attention_head_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
@@ -546,9 +605,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -569,6 +629,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             up_block = get_up_block(
                 up_block_type,
                 num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -578,7 +639,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=reversed_cross_attention_dim[i],
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -587,6 +648,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -634,11 +696,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-
-
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.

         """
         count = len(self.attn_processors.keys())
@@ -672,13 +738,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         r"""
         Enable sliced attention computation.

-        When this option is enabled, the attention module
-
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`,
-                `"max"`, maximum amount of memory
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
@@ -753,29 +819,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
+        The [`UNetFlatConditionModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`):
-
-
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             encoder_attention_mask (`torch.Tensor`):
-                (batch, sequence_length)
-
-                corresponding to "discard" tokens.
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
             cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `
-
-
-
-                A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
-                embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
-                `addition_embed_type` for more information.
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.

         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-
-
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -841,6 +909,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None

         if self.class_embedding is not None:
             if class_labels is None:
@@ -862,9 +931,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
-            emb = emb + aug_emb
         elif self.config.addition_embed_type == "text_image":
-            #
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
@@ -873,9 +941,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

             image_embs = added_cond_kwargs.get("image_embeds")
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
-
             aug_emb = self.add_embedding(text_embs, image_embs)
-            emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+                    " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+                    " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
+                    " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
+                    " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb

         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
@@ -892,7 +999,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
-
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
+                    " the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         # 2. pre-process
         sample = self.conv_in(sample)

@@ -1187,12 +1302,13 @@ class CrossAttnDownBlockFlat(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         downsample_padding=1,
@@ -1207,7 +1323,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         attentions = []

         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads

         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -1228,10 +1344,10 @@ class CrossAttnDownBlockFlat(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1242,8 +1358,8 @@ class CrossAttnDownBlockFlat(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -1421,12 +1537,13 @@ class CrossAttnUpBlockFlat(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -1440,7 +1557,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         attentions = []

         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads

         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -1463,10 +1580,10 @@ class CrossAttnUpBlockFlat(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1477,8 +1594,8 @@ class CrossAttnUpBlockFlat(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        out_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
                         in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -1567,12 +1684,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
@@ -1582,7 +1700,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         super().__init__()

         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

         # there is always at least one resnet
@@ -1606,10 +1724,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=1,
+                        num_layers=transformer_layers_per_block,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1619,8 +1737,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
             else:
                 attentions.append(
                     DualTransformer2DModel(
-                        attn_num_head_channels,
-                        in_channels // attn_num_head_channels,
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
                         in_channels=in_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
@@ -1682,7 +1800,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        attention_head_dim=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
@@ -1693,10 +1811,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):

         self.has_cross_attention = True

-        self.attn_num_head_channels = attn_num_head_channels
+        self.attention_head_dim = attention_head_dim
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

-        self.num_heads = in_channels // self.attn_num_head_channels
+        self.num_heads = in_channels // self.attention_head_dim

         # there is always at least one resnet
         resnets = [
@@ -1726,7 +1844,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                 query_dim=in_channels,
                 cross_attention_dim=in_channels,
                 heads=self.num_heads,
-                dim_head=attn_num_head_channels,
+                dim_head=self.attention_head_dim,
                 added_kv_proj_dim=cross_attention_dim,
                 norm_num_groups=resnet_groups,
                 bias=True,