diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1992 @@
1
+ import math
2
+ from typing import Dict, Optional, Tuple, Union, Callable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ import torch.nn.functional as F
+ from functools import partial  # needed by the "sde_vp" upsample/downsample branches below
8
+
9
+ from diffsynth_engine.models.base import PreTrainedModel
10
+
11
+
12
+ ACT2CLS = {
13
+ "swish": nn.SiLU,
14
+ "silu": nn.SiLU,
15
+ "mish": nn.Mish,
16
+ "gelu": nn.GELU,
17
+ "relu": nn.ReLU,
18
+ }
19
+
20
+
21
+ def get_activation(act_fn: str) -> nn.Module:
22
+ """Helper function to get activation function from string.
23
+
24
+ Args:
25
+ act_fn (str): Name of activation function.
26
+
27
+ Returns:
28
+ nn.Module: Activation function.
29
+ """
30
+
31
+ act_fn = act_fn.lower()
32
+ if act_fn in ACT2CLS:
33
+ return ACT2CLS[act_fn]()
34
+ else:
35
+ raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
36
+
37
+
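# Usage sketch (illustrative only; uses nothing beyond the definitions above).
# `get_activation` returns an *instance*, so it can be dropped straight into a module
# graph; unknown names such as "tanh" raise a ValueError.
act = get_activation("silu")                                   # -> nn.SiLU()
mlp = nn.Sequential(nn.Linear(64, 64), get_activation("gelu"), nn.Linear(64, 64))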
38
+ class ResnetBlock2D(nn.Module):
39
+ r"""
40
+ A Resnet block.
41
+
42
+ Parameters:
43
+ in_channels (`int`): The number of channels in the input.
44
+ out_channels (`int`, *optional*, default to `None`):
45
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
46
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
47
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
48
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
49
+ groups_out (`int`, *optional*, default to None):
50
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
51
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
52
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
53
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
54
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
55
+ stronger conditioning with scale and shift.
56
+ kernel (`torch.Tensor`, optional, default to None): FIR filter, see
57
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
58
+ output_scale_factor (`float`, *optional*, default to `1.0`): the scale factor to use for the output.
59
+ use_in_shortcut (`bool`, *optional*, default to `True`):
60
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
61
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
62
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
63
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
64
+ `conv_shortcut` output.
65
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
66
+ If None, same as `out_channels`.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ *,
72
+ in_channels: int,
73
+ out_channels: Optional[int] = None,
74
+ conv_shortcut: bool = False,
75
+ dropout: float = 0.0,
76
+ temb_channels: int = 512,
77
+ groups: int = 32,
78
+ groups_out: Optional[int] = None,
79
+ pre_norm: bool = True,
80
+ eps: float = 1e-6,
81
+ non_linearity: str = "swish",
82
+ skip_time_act: bool = False,
83
+ time_embedding_norm: str = "default", # default, scale_shift,
84
+ kernel: Optional[torch.Tensor] = None,
85
+ output_scale_factor: float = 1.0,
86
+ use_in_shortcut: Optional[bool] = None,
87
+ up: bool = False,
88
+ down: bool = False,
89
+ conv_shortcut_bias: bool = True,
90
+ conv_2d_out_channels: Optional[int] = None,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.pre_norm = True
95
+ self.in_channels = in_channels
96
+ out_channels = in_channels if out_channels is None else out_channels
97
+ self.out_channels = out_channels
98
+ self.use_conv_shortcut = conv_shortcut
99
+ self.up = up
100
+ self.down = down
101
+ self.output_scale_factor = output_scale_factor
102
+ self.time_embedding_norm = time_embedding_norm
103
+ self.skip_time_act = skip_time_act
104
+
105
+ if groups_out is None:
106
+ groups_out = groups
107
+
108
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
109
+
110
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
111
+
112
+ if temb_channels is not None:
113
+ if self.time_embedding_norm == "default":
114
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
115
+ elif self.time_embedding_norm == "scale_shift":
116
+ self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
117
+ else:
118
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
119
+ else:
120
+ self.time_emb_proj = None
121
+
122
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
123
+
124
+ self.dropout = torch.nn.Dropout(dropout)
125
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
126
+ self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
127
+
128
+ self.nonlinearity = get_activation(non_linearity)
129
+
130
+ self.upsample = self.downsample = None
131
+ if self.up:
132
+ if kernel == "fir":
133
+ fir_kernel = (1, 3, 3, 1)
134
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
135
+ elif kernel == "sde_vp":
136
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
137
+ else:
138
+ self.upsample = Upsample2D(in_channels, use_conv=False)
139
+ elif self.down:
140
+ if kernel == "fir":
141
+ fir_kernel = (1, 3, 3, 1)
142
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
143
+ elif kernel == "sde_vp":
144
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
145
+ else:
146
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
147
+
148
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
149
+
150
+ self.conv_shortcut = None
151
+ if self.use_in_shortcut:
152
+ self.conv_shortcut = nn.Conv2d(
153
+ in_channels,
154
+ conv_2d_out_channels,
155
+ kernel_size=1,
156
+ stride=1,
157
+ padding=0,
158
+ bias=conv_shortcut_bias,
159
+ )
160
+
161
+ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
162
+ hidden_states = input_tensor
163
+
164
+ hidden_states = self.norm1(hidden_states)
165
+ hidden_states = self.nonlinearity(hidden_states)
166
+
167
+ if self.upsample is not None:
168
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
169
+ if hidden_states.shape[0] >= 64:
170
+ input_tensor = input_tensor.contiguous()
171
+ hidden_states = hidden_states.contiguous()
172
+ input_tensor = self.upsample(input_tensor)
173
+ hidden_states = self.upsample(hidden_states)
174
+ elif self.downsample is not None:
175
+ input_tensor = self.downsample(input_tensor)
176
+ hidden_states = self.downsample(hidden_states)
177
+
178
+ hidden_states = self.conv1(hidden_states)
179
+
180
+ if self.time_emb_proj is not None:
181
+ if not self.skip_time_act:
182
+ temb = self.nonlinearity(temb)
183
+ temb = self.time_emb_proj(temb)[:, :, None, None]
184
+
185
+ if self.time_embedding_norm == "default":
186
+ if temb is not None:
187
+ hidden_states = hidden_states + temb
188
+ hidden_states = self.norm2(hidden_states)
189
+ elif self.time_embedding_norm == "scale_shift":
190
+ if temb is None:
191
+ raise ValueError(
192
+ f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
193
+ )
194
+ time_scale, time_shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = self.norm2(hidden_states)
196
+ hidden_states = hidden_states * (1 + time_scale) + time_shift
197
+ else:
198
+ hidden_states = self.norm2(hidden_states)
199
+
200
+ hidden_states = self.nonlinearity(hidden_states)
201
+
202
+ hidden_states = self.dropout(hidden_states)
203
+ hidden_states = self.conv2(hidden_states)
204
+
205
+ if self.conv_shortcut is not None:
206
+ input_tensor = self.conv_shortcut(input_tensor.contiguous())
207
+
208
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
209
+
210
+ return output_tensor
211
+
212
+
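# Usage sketch (illustrative only; shapes chosen arbitrarily). A block mapping
# 64 -> 128 channels, conditioned on a 512-dim timestep embedding via the default
# additive ("default") shift; the 1x1 conv_shortcut is created automatically
# because the channel count changes.
block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 512)
out = block(x, temb)                                           # -> (1, 128, 32, 32)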
213
+ class Downsample2D(nn.Module):
214
+ """A 2D downsampling layer with an optional convolution.
215
+
216
+ Parameters:
217
+ channels (`int`):
218
+ number of channels in the inputs and outputs.
219
+ use_conv (`bool`, default `False`):
220
+ option to use a convolution.
221
+ out_channels (`int`, optional):
222
+ number of output channels. Defaults to `channels`.
223
+ padding (`int`, default `1`):
224
+ padding for the convolution.
225
+ name (`str`, default `conv`):
226
+ name of the downsampling 2D layer.
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ channels: int,
232
+ use_conv: bool = False,
233
+ out_channels: Optional[int] = None,
234
+ padding: int = 1,
235
+ name: str = "conv",
236
+ kernel_size=3,
237
+ norm_type=None,
238
+ eps=None,
239
+ elementwise_affine=None,
240
+ bias=True,
241
+ ):
242
+ super().__init__()
243
+ self.channels = channels
244
+ self.out_channels = out_channels or channels
245
+ self.use_conv = use_conv
246
+ self.padding = padding
247
+ stride = 2
248
+ self.name = name
249
+
250
+ if norm_type == "ln_norm":
251
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
252
+ elif norm_type == "rms_norm":
253
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
254
+ elif norm_type is None:
255
+ self.norm = None
256
+ else:
257
+ raise ValueError(f"unknown norm_type: {norm_type}")
258
+
259
+ if use_conv:
260
+ conv = nn.Conv2d(
261
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
262
+ )
263
+ else:
264
+ assert self.channels == self.out_channels
265
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
266
+
267
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
268
+ if name == "conv":
269
+ self.Conv2d_0 = conv
270
+ self.conv = conv
271
+ elif name == "Conv2d_0":
272
+ self.conv = conv
273
+ else:
274
+ self.conv = conv
275
+
276
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
277
+ assert hidden_states.shape[1] == self.channels
278
+
279
+ if self.norm is not None:
280
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
281
+
282
+ if self.use_conv and self.padding == 0:
283
+ pad = (0, 1, 0, 1)
284
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
285
+
286
+ assert hidden_states.shape[1] == self.channels
287
+
288
+ hidden_states = self.conv(hidden_states)
289
+
290
+ return hidden_states
291
+
292
+
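# Usage sketch (illustrative only): with use_conv=True a strided 3x3 conv halves the
# spatial resolution; with use_conv=False an AvgPool2d is used and channels must match.
down = Downsample2D(channels=64, use_conv=True, out_channels=64, padding=1, name="op")
h = down(torch.randn(1, 64, 32, 32))                           # -> (1, 64, 16, 16)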
293
+ class Upsample2D(nn.Module):
294
+ """A 2D upsampling layer with an optional convolution.
295
+
296
+ Parameters:
297
+ channels (`int`):
298
+ number of channels in the inputs and outputs.
299
+ use_conv (`bool`, default `False`):
300
+ option to use a convolution.
301
+ use_conv_transpose (`bool`, default `False`):
302
+ option to use a convolution transpose.
303
+ out_channels (`int`, optional):
304
+ number of output channels. Defaults to `channels`.
305
+ name (`str`, default `conv`):
306
+ name of the upsampling 2D layer.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ channels: int,
312
+ use_conv: bool = False,
313
+ use_conv_transpose: bool = False,
314
+ out_channels: Optional[int] = None,
315
+ name: str = "conv",
316
+ kernel_size: Optional[int] = None,
317
+ padding=1,
318
+ norm_type=None,
319
+ eps=None,
320
+ elementwise_affine=None,
321
+ bias=True,
322
+ interpolate=True,
323
+ ):
324
+ super().__init__()
325
+ self.channels = channels
326
+ self.out_channels = out_channels or channels
327
+ self.use_conv = use_conv
328
+ self.use_conv_transpose = use_conv_transpose
329
+ self.name = name
330
+ self.interpolate = interpolate
331
+
332
+ if norm_type == "ln_norm":
333
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
334
+ elif norm_type == "rms_norm":
335
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
336
+ elif norm_type is None:
337
+ self.norm = None
338
+ else:
339
+ raise ValueError(f"unknown norm_type: {norm_type}")
340
+
341
+ conv = None
342
+ if use_conv_transpose:
343
+ if kernel_size is None:
344
+ kernel_size = 4
345
+ conv = nn.ConvTranspose2d(
346
+ channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
347
+ )
348
+ elif use_conv:
349
+ if kernel_size is None:
350
+ kernel_size = 3
351
+ conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
352
+
353
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
354
+ if name == "conv":
355
+ self.conv = conv
356
+ else:
357
+ self.Conv2d_0 = conv
358
+
359
+ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
360
+ assert hidden_states.shape[1] == self.channels
361
+
362
+ if self.norm is not None:
363
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
364
+
365
+ if self.use_conv_transpose:
366
+ return self.conv(hidden_states)
367
+
368
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
369
+ # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
370
+ dtype = hidden_states.dtype
371
+ if dtype == torch.bfloat16:
372
+ hidden_states = hidden_states.to(torch.float32)
373
+
374
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
375
+ if hidden_states.shape[0] >= 64:
376
+ hidden_states = hidden_states.contiguous()
377
+
378
+ # if `output_size` is passed we force the interpolation output
379
+ # size and do not make use of `scale_factor=2`
380
+ if self.interpolate:
381
+ # upsample_nearest_nhwc also fails when the number of output elements is large
382
+ # https://github.com/pytorch/pytorch/issues/141831
383
+ scale_factor = (
384
+ 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])
385
+ )
386
+ if hidden_states.numel() * scale_factor > pow(2, 31):
387
+ hidden_states = hidden_states.contiguous()
388
+
389
+ if output_size is None:
390
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
391
+ else:
392
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
393
+
394
+ # Cast back to original dtype
395
+ if dtype == torch.bfloat16:
396
+ hidden_states = hidden_states.to(dtype)
397
+
398
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
399
+ if self.use_conv:
400
+ if self.name == "conv":
401
+ hidden_states = self.conv(hidden_states)
402
+ else:
403
+ hidden_states = self.Conv2d_0(hidden_states)
404
+
405
+ return hidden_states
406
+
407
+
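# Usage sketch (illustrative only): nearest-neighbour 2x interpolation followed by a
# 3x3 conv (the configuration UpDecoderBlock2D below uses).
up = Upsample2D(channels=64, use_conv=True, out_channels=64)
h = up(torch.randn(1, 64, 16, 16))                             # -> (1, 64, 32, 32)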
408
+ class Attention(nn.Module):
409
+ r"""
410
+ A cross attention layer.
411
+
412
+ Parameters:
413
+ query_dim (`int`):
414
+ The number of channels in the query.
415
+ cross_attention_dim (`int`, *optional*):
416
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
417
+ heads (`int`, *optional*, defaults to 8):
418
+ The number of heads to use for multi-head attention.
419
+ kv_heads (`int`, *optional*, defaults to `None`):
420
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
421
+ `kv_heads=heads`, the model uses Multi-Head Attention (MHA); if `kv_heads=1`, it uses Multi-Query
422
+ Attention (MQA); otherwise Grouped-Query Attention (GQA) is used.
423
+ dim_head (`int`, *optional*, defaults to 64):
424
+ The number of channels in each head.
425
+ dropout (`float`, *optional*, defaults to 0.0):
426
+ The dropout probability to use.
427
+ bias (`bool`, *optional*, defaults to False):
428
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
429
+ upcast_attention (`bool`, *optional*, defaults to False):
430
+ Set to `True` to upcast the attention computation to `float32`.
431
+ upcast_softmax (`bool`, *optional*, defaults to False):
432
+ Set to `True` to upcast the softmax computation to `float32`.
433
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
434
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
435
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
436
+ The number of groups to use for the group norm in the cross attention.
437
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
438
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
439
+ norm_num_groups (`int`, *optional*, defaults to `None`):
440
+ The number of groups to use for the group norm in the attention.
441
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
442
+ The number of channels to use for the spatial normalization.
443
+ out_bias (`bool`, *optional*, defaults to `True`):
444
+ Set to `True` to use a bias in the output linear layer.
445
+ scale_qk (`bool`, *optional*, defaults to `True`):
446
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
447
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
448
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
449
+ `added_kv_proj_dim` is not `None`.
450
+ eps (`float`, *optional*, defaults to 1e-5):
451
+ An additional value added to the denominator in group normalization that is used for numerical stability.
452
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
453
+ A factor to rescale the output by dividing it with this value.
454
+ residual_connection (`bool`, *optional*, defaults to `False`):
455
+ Set to `True` to add the residual connection to the output.
456
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
457
+ Set to `True` if the attention block is loaded from a deprecated state dict.
458
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
459
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
460
+ `AttnProcessor` otherwise.
461
+ """
462
+
463
+ def __init__(
464
+ self,
465
+ query_dim: int,
466
+ cross_attention_dim: Optional[int] = None,
467
+ heads: int = 8,
468
+ kv_heads: Optional[int] = None,
469
+ dim_head: int = 64,
470
+ dropout: float = 0.0,
471
+ bias: bool = False,
472
+ upcast_attention: bool = False,
473
+ upcast_softmax: bool = False,
474
+ cross_attention_norm: Optional[str] = None,
475
+ cross_attention_norm_num_groups: int = 32,
476
+ qk_norm: Optional[str] = None,
477
+ added_kv_proj_dim: Optional[int] = None,
478
+ added_proj_bias: Optional[bool] = True,
479
+ norm_num_groups: Optional[int] = None,
480
+ spatial_norm_dim: Optional[int] = None,
481
+ out_bias: bool = True,
482
+ scale_qk: bool = True,
483
+ only_cross_attention: bool = False,
484
+ eps: float = 1e-5,
485
+ rescale_output_factor: float = 1.0,
486
+ residual_connection: bool = False,
487
+ _from_deprecated_attn_block: bool = False,
488
+ processor: Optional["AttnProcessor"] = None,
489
+ out_dim: int = None,
490
+ out_context_dim: int = None,
491
+ context_pre_only=None,
492
+ pre_only=False,
493
+ elementwise_affine: bool = True,
494
+ is_causal: bool = False,
495
+ ):
496
+ super().__init__()
497
+
498
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
499
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
500
+ self.query_dim = query_dim
501
+ self.use_bias = bias
502
+ self.is_cross_attention = cross_attention_dim is not None
503
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
504
+ self.upcast_attention = upcast_attention
505
+ self.upcast_softmax = upcast_softmax
506
+ self.rescale_output_factor = rescale_output_factor
507
+ self.residual_connection = residual_connection
508
+ self.dropout = dropout
509
+ self.fused_projections = False
510
+ self.out_dim = out_dim if out_dim is not None else query_dim
511
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
512
+ self.context_pre_only = context_pre_only
513
+ self.pre_only = pre_only
514
+ self.is_causal = is_causal
515
+
516
+ # we make use of this private variable to know whether this class is loaded
517
+ # with a deprecated state dict so that we can convert it on the fly
518
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
519
+
520
+ self.scale_qk = scale_qk
521
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
522
+
523
+ self.heads = out_dim // dim_head if out_dim is not None else heads
524
+ # for slice_size > 0 the attention score computation
525
+ # is split across the batch axis to save memory
526
+ # You can set slice_size with `set_attention_slice`
527
+ self.sliceable_head_dim = heads
528
+
529
+ self.added_kv_proj_dim = added_kv_proj_dim
530
+ self.only_cross_attention = only_cross_attention
531
+
532
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
533
+ raise ValueError(
534
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
535
+ )
536
+
537
+ if norm_num_groups is not None:
538
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
539
+ else:
540
+ self.group_norm = None
541
+
542
+ if spatial_norm_dim is not None:
543
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
544
+ else:
545
+ self.spatial_norm = None
546
+
547
+ if qk_norm is None:
548
+ self.norm_q = None
549
+ self.norm_k = None
550
+ elif qk_norm == "layer_norm":
551
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
552
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
553
+ elif qk_norm == "fp32_layer_norm":
554
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
555
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
556
+ elif qk_norm == "layer_norm_across_heads":
557
+ # Lumina applies qk norm across all heads
558
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
559
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
560
+ elif qk_norm == "rms_norm":
561
+ self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
562
+ self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
563
+ elif qk_norm == "rms_norm_across_heads":
564
+ # LTX applies qk norm across all heads
565
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
566
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
567
+ elif qk_norm == "l2":
568
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
569
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
570
+ else:
571
+ raise ValueError(
572
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
573
+ )
574
+
575
+ if cross_attention_norm is None:
576
+ self.norm_cross = None
577
+ elif cross_attention_norm == "layer_norm":
578
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
579
+ elif cross_attention_norm == "group_norm":
580
+ if self.added_kv_proj_dim is not None:
581
+ # The given `encoder_hidden_states` are initially of shape
582
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
583
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
584
+ # before the projection, so we need to use `added_kv_proj_dim` as
585
+ # the number of channels for the group norm.
586
+ norm_cross_num_channels = added_kv_proj_dim
587
+ else:
588
+ norm_cross_num_channels = self.cross_attention_dim
589
+
590
+ self.norm_cross = nn.GroupNorm(
591
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
592
+ )
593
+ else:
594
+ raise ValueError(
595
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
596
+ )
597
+
598
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
599
+
600
+ if not self.only_cross_attention:
601
+ # only relevant for the `AddedKVProcessor` classes
602
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
603
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
604
+ else:
605
+ self.to_k = None
606
+ self.to_v = None
607
+
608
+ self.added_proj_bias = added_proj_bias
609
+ if self.added_kv_proj_dim is not None:
610
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
611
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
612
+ if self.context_pre_only is not None:
613
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
614
+ else:
615
+ self.add_q_proj = None
616
+ self.add_k_proj = None
617
+ self.add_v_proj = None
618
+
619
+ if not self.pre_only:
620
+ self.to_out = nn.ModuleList([])
621
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
622
+ self.to_out.append(nn.Dropout(dropout))
623
+ else:
624
+ self.to_out = None
625
+
626
+ if self.context_pre_only is not None and not self.context_pre_only:
627
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
628
+ else:
629
+ self.to_add_out = None
630
+
631
+ if qk_norm is not None and added_kv_proj_dim is not None:
632
+ if qk_norm == "layer_norm":
633
+ self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
634
+ self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
635
+ elif qk_norm == "fp32_layer_norm":
636
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
637
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
638
+ elif qk_norm == "rms_norm":
639
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
640
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
641
+ elif qk_norm == "rms_norm_across_heads":
642
+ # Wan applies qk norm across all heads
643
+ # Wan also doesn't apply a q norm
644
+ self.norm_added_q = None
645
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
646
+ else:
647
+ raise ValueError(
648
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
649
+ )
650
+ else:
651
+ self.norm_added_q = None
652
+ self.norm_added_k = None
653
+
654
+ # set attention processor
655
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
656
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
657
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
658
+ if processor is None:
659
+ processor = (
660
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
661
+ )
662
+ self.set_processor(processor)
663
+
664
+ def set_processor(self, processor: "AttnProcessor") -> None:
665
+ r"""
666
+ Set the attention processor to use.
667
+
668
+ Args:
669
+ processor (`AttnProcessor`):
670
+ The attention processor to use.
671
+ """
672
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
673
+ # pop `processor` from `self._modules`
674
+ if (
675
+ hasattr(self, "processor")
676
+ and isinstance(self.processor, torch.nn.Module)
677
+ and not isinstance(processor, torch.nn.Module)
678
+ ):
679
+ self._modules.pop("processor")
680
+
681
+ self.processor = processor
682
+
683
+ def forward(
684
+ self,
685
+ hidden_states: torch.Tensor,
686
+ encoder_hidden_states: Optional[torch.Tensor] = None,
687
+ attention_mask: Optional[torch.Tensor] = None,
688
+ **cross_attention_kwargs,
689
+ ) -> torch.Tensor:
690
+ r"""
691
+ The forward method of the `Attention` class.
692
+
693
+ Args:
694
+ hidden_states (`torch.Tensor`):
695
+ The hidden states of the query.
696
+ encoder_hidden_states (`torch.Tensor`, *optional*):
697
+ The hidden states of the encoder.
698
+ attention_mask (`torch.Tensor`, *optional*):
699
+ The attention mask to use. If `None`, no mask is applied.
700
+ **cross_attention_kwargs:
701
+ Additional keyword arguments to pass along to the cross attention.
702
+
703
+ Returns:
704
+ `torch.Tensor`: The output of the attention layer.
705
+ """
706
+ # The `Attention` class can call different attention processors / attention functions
707
+ # here we simply pass along all tensors to the selected processor class
708
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
709
+
710
+ return self.processor(
711
+ self,
712
+ hidden_states,
713
+ encoder_hidden_states=encoder_hidden_states,
714
+ attention_mask=attention_mask,
715
+ **cross_attention_kwargs,
716
+ )
717
+
718
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
719
+ r"""
720
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
721
+ is the number of heads initialized while constructing the `Attention` class.
722
+
723
+ Args:
724
+ tensor (`torch.Tensor`): The tensor to reshape.
725
+
726
+ Returns:
727
+ `torch.Tensor`: The reshaped tensor.
728
+ """
729
+ head_size = self.heads
730
+ batch_size, seq_len, dim = tensor.shape
731
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
732
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
733
+ return tensor
734
+
735
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
736
+ r"""
737
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, heads, seq_len, dim // heads]`. `heads` is
738
+ the number of heads initialized while constructing the `Attention` class.
739
+
740
+ Args:
741
+ tensor (`torch.Tensor`): The tensor to reshape.
742
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
743
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
744
+
745
+ Returns:
746
+ `torch.Tensor`: The reshaped tensor.
747
+ """
748
+ head_size = self.heads
749
+ if tensor.ndim == 3:
750
+ batch_size, seq_len, dim = tensor.shape
751
+ extra_dim = 1
752
+ else:
753
+ batch_size, extra_dim, seq_len, dim = tensor.shape
754
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
755
+ tensor = tensor.permute(0, 2, 1, 3)
756
+
757
+ if out_dim == 3:
758
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
759
+
760
+ return tensor
761
+
762
+ def get_attention_scores(
763
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
764
+ ) -> torch.Tensor:
765
+ r"""
766
+ Compute the attention scores.
767
+
768
+ Args:
769
+ query (`torch.Tensor`): The query tensor.
770
+ key (`torch.Tensor`): The key tensor.
771
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
772
+
773
+ Returns:
774
+ `torch.Tensor`: The attention probabilities/scores.
775
+ """
776
+ dtype = query.dtype
777
+ if self.upcast_attention:
778
+ query = query.float()
779
+ key = key.float()
780
+
781
+ if attention_mask is None:
782
+ baddbmm_input = torch.empty(
783
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
784
+ )
785
+ beta = 0
786
+ else:
787
+ baddbmm_input = attention_mask
788
+ beta = 1
789
+
790
+ attention_scores = torch.baddbmm(
791
+ baddbmm_input,
792
+ query,
793
+ key.transpose(-1, -2),
794
+ beta=beta,
795
+ alpha=self.scale,
796
+ )
797
+ del baddbmm_input
798
+
799
+ if self.upcast_softmax:
800
+ attention_scores = attention_scores.float()
801
+
802
+ attention_probs = attention_scores.softmax(dim=-1)
803
+ del attention_scores
804
+
805
+ attention_probs = attention_probs.to(dtype)
806
+
807
+ return attention_probs
808
+
809
+ def prepare_attention_mask(
810
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
811
+ ) -> torch.Tensor:
812
+ r"""
813
+ Prepare the attention mask for the attention computation.
814
+
815
+ Args:
816
+ attention_mask (`torch.Tensor`):
817
+ The attention mask to prepare.
818
+ target_length (`int`):
819
+ The target length of the attention mask. This is the length of the attention mask after padding.
820
+ batch_size (`int`):
821
+ The batch size, which is used to repeat the attention mask.
822
+ out_dim (`int`, *optional*, defaults to `3`):
823
+ The output dimension of the attention mask. Can be either `3` or `4`.
824
+
825
+ Returns:
826
+ `torch.Tensor`: The prepared attention mask.
827
+ """
828
+ head_size = self.heads
829
+ if attention_mask is None:
830
+ return attention_mask
831
+
832
+ current_length: int = attention_mask.shape[-1]
833
+ if current_length != target_length:
834
+ if attention_mask.device.type == "mps":
835
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
836
+ # Instead, we can manually construct the padding tensor.
837
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
838
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
839
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
840
+ else:
841
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
842
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
843
+ # remaining_length: int = target_length - current_length
844
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
845
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
846
+
847
+ if out_dim == 3:
848
+ if attention_mask.shape[0] < batch_size * head_size:
849
+ attention_mask = attention_mask.repeat_interleave(
850
+ head_size, dim=0, output_size=attention_mask.shape[0] * head_size
851
+ )
852
+ elif out_dim == 4:
853
+ attention_mask = attention_mask.unsqueeze(1)
854
+ attention_mask = attention_mask.repeat_interleave(
855
+ head_size, dim=1, output_size=attention_mask.shape[1] * head_size
856
+ )
857
+
858
+ return attention_mask
859
+
860
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
861
+ r"""
862
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
863
+ `Attention` class.
864
+
865
+ Args:
866
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
867
+
868
+ Returns:
869
+ `torch.Tensor`: The normalized encoder hidden states.
870
+ """
871
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
872
+
873
+ if isinstance(self.norm_cross, nn.LayerNorm):
874
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
875
+ elif isinstance(self.norm_cross, nn.GroupNorm):
876
+ # Group norm norms along the channels dimension and expects
877
+ # input to be in the shape of (N, C, *). In this case, we want
878
+ # to norm along the hidden dimension, so we need to move
879
+ # (batch_size, sequence_length, hidden_size) ->
880
+ # (batch_size, hidden_size, sequence_length)
881
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
882
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
883
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
884
+ else:
885
+ assert False
886
+
887
+ return encoder_hidden_states
888
+
889
+ @torch.no_grad()
890
+ def fuse_projections(self, fuse=True):
891
+ device = self.to_q.weight.data.device
892
+ dtype = self.to_q.weight.data.dtype
893
+
894
+ if not self.is_cross_attention:
895
+ # fetch weight matrices.
896
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
897
+ in_features = concatenated_weights.shape[1]
898
+ out_features = concatenated_weights.shape[0]
899
+
900
+ # create a new single projection layer and copy over the weights.
901
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
902
+ self.to_qkv.weight.copy_(concatenated_weights)
903
+ if self.use_bias:
904
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
905
+ self.to_qkv.bias.copy_(concatenated_bias)
906
+
907
+ else:
908
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
909
+ in_features = concatenated_weights.shape[1]
910
+ out_features = concatenated_weights.shape[0]
911
+
912
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
913
+ self.to_kv.weight.copy_(concatenated_weights)
914
+ if self.use_bias:
915
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
916
+ self.to_kv.bias.copy_(concatenated_bias)
917
+
918
+ # handle added projections for SD3 and others.
919
+ if (
920
+ getattr(self, "add_q_proj", None) is not None
921
+ and getattr(self, "add_k_proj", None) is not None
922
+ and getattr(self, "add_v_proj", None) is not None
923
+ ):
924
+ concatenated_weights = torch.cat(
925
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
926
+ )
927
+ in_features = concatenated_weights.shape[1]
928
+ out_features = concatenated_weights.shape[0]
929
+
930
+ self.to_added_qkv = nn.Linear(
931
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
932
+ )
933
+ self.to_added_qkv.weight.copy_(concatenated_weights)
934
+ if self.added_proj_bias:
935
+ concatenated_bias = torch.cat(
936
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
937
+ )
938
+ self.to_added_qkv.bias.copy_(concatenated_bias)
939
+
940
+ self.fused_projections = fuse
941
+
942
+
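# Usage sketch (illustrative only; dimensions are arbitrary). With cross_attention_dim
# left as None the layer attends over its own input; passing `encoder_hidden_states`
# whose last dimension equals `cross_attention_dim` turns it into cross-attention.
attn = Attention(query_dim=320, heads=8, dim_head=40)
h = attn(torch.randn(2, 77, 320))                              # -> (2, 77, 320)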
943
+ class AttnProcessor2_0:
944
+ r"""
945
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
946
+ """
947
+
948
+ def __init__(self):
949
+ if not hasattr(F, "scaled_dot_product_attention"):
950
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
951
+
952
+ def __call__(
953
+ self,
954
+ attn: Attention,
955
+ hidden_states: torch.Tensor,
956
+ encoder_hidden_states: Optional[torch.Tensor] = None,
957
+ attention_mask: Optional[torch.Tensor] = None,
958
+ temb: Optional[torch.Tensor] = None,
959
+ *args,
960
+ **kwargs,
961
+ ) -> torch.Tensor:
962
+ residual = hidden_states
963
+ if attn.spatial_norm is not None:
964
+ hidden_states = attn.spatial_norm(hidden_states, temb)
965
+
966
+ input_ndim = hidden_states.ndim
967
+
968
+ if input_ndim == 4:
969
+ batch_size, channel, height, width = hidden_states.shape
970
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
971
+
972
+ batch_size, sequence_length, _ = (
973
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
974
+ )
975
+
976
+ if attention_mask is not None:
977
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
978
+ # scaled_dot_product_attention expects attention_mask shape to be
979
+ # (batch, heads, source_length, target_length)
980
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
981
+
982
+ if attn.group_norm is not None:
983
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
984
+
985
+ query = attn.to_q(hidden_states)
986
+
987
+ if encoder_hidden_states is None:
988
+ encoder_hidden_states = hidden_states
989
+ elif attn.norm_cross:
990
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
991
+
992
+ key = attn.to_k(encoder_hidden_states)
993
+ value = attn.to_v(encoder_hidden_states)
994
+
995
+ inner_dim = key.shape[-1]
996
+ head_dim = inner_dim // attn.heads
997
+
998
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
999
+
1000
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1001
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1002
+
1003
+ if attn.norm_q is not None:
1004
+ query = attn.norm_q(query)
1005
+ if attn.norm_k is not None:
1006
+ key = attn.norm_k(key)
1007
+
1008
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1009
+ # TODO: add support for attn.scale when we move to Torch 2.1
1010
+ hidden_states = F.scaled_dot_product_attention(
1011
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1012
+ )
1013
+
1014
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1015
+ hidden_states = hidden_states.to(query.dtype)
1016
+
1017
+ # linear proj
1018
+ hidden_states = attn.to_out[0](hidden_states)
1019
+ # dropout
1020
+ hidden_states = attn.to_out[1](hidden_states)
1021
+
1022
+ if input_ndim == 4:
1023
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1024
+
1025
+ if attn.residual_connection:
1026
+ hidden_states = hidden_states + residual
1027
+
1028
+ hidden_states = hidden_states / attn.rescale_output_factor
1029
+
1030
+ return hidden_states
1031
+
1032
+
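# Usage sketch (illustrative only): the processor is a plain callable that can be swapped
# onto an existing Attention module via set_processor; this default one routes the
# attention computation through torch.nn.functional.scaled_dot_product_attention.
layer = Attention(query_dim=128, heads=4, dim_head=32)
layer.set_processor(AttnProcessor2_0())
out = layer(torch.randn(1, 16, 128))                           # -> (1, 16, 128)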
1033
+ class UNetMidBlock2D(nn.Module):
1034
+ """
1035
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
1036
+
1037
+ Args:
1038
+ in_channels (`int`): The number of input channels.
1039
+ temb_channels (`int`): The number of temporal embedding channels.
1040
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
1041
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
1042
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
1043
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
1044
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
1045
+ model on tasks with long-range temporal dependencies.
1046
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
1047
+ resnet_groups (`int`, *optional*, defaults to 32):
1048
+ The number of groups to use in the group normalization layers of the resnet blocks.
1049
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
1050
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
1051
+ Whether to use pre-normalization for the resnet blocks.
1052
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
1053
+ attention_head_dim (`int`, *optional*, defaults to 1):
1054
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
1055
+ the number of input channels.
1056
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
1057
+
1058
+ Returns:
1059
+ `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
1060
+ height, width)`.
1061
+
1062
+ """
1063
+
1064
+ def __init__(
1065
+ self,
1066
+ in_channels: int,
1067
+ temb_channels: int,
1068
+ dropout: float = 0.0,
1069
+ num_layers: int = 1,
1070
+ resnet_eps: float = 1e-6,
1071
+ resnet_time_scale_shift: str = "default", # default, spatial
1072
+ resnet_act_fn: str = "swish",
1073
+ resnet_groups: int = 32,
1074
+ attn_groups: Optional[int] = None,
1075
+ resnet_pre_norm: bool = True,
1076
+ add_attention: bool = True,
1077
+ attention_head_dim: int = 1,
1078
+ output_scale_factor: float = 1.0,
1079
+ ):
1080
+ super().__init__()
1081
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
1082
+ self.add_attention = add_attention
1083
+
1084
+ if attn_groups is None:
1085
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
1086
+
1087
+ # there is always at least one resnet
1088
+ if resnet_time_scale_shift == "spatial":
1089
+ resnets = [
1090
+ ResnetBlockCondNorm2D(
1091
+ in_channels=in_channels,
1092
+ out_channels=in_channels,
1093
+ temb_channels=temb_channels,
1094
+ eps=resnet_eps,
1095
+ groups=resnet_groups,
1096
+ dropout=dropout,
1097
+ time_embedding_norm="spatial",
1098
+ non_linearity=resnet_act_fn,
1099
+ output_scale_factor=output_scale_factor,
1100
+ )
1101
+ ]
1102
+ else:
1103
+ resnets = [
1104
+ ResnetBlock2D(
1105
+ in_channels=in_channels,
1106
+ out_channels=in_channels,
1107
+ temb_channels=temb_channels,
1108
+ eps=resnet_eps,
1109
+ groups=resnet_groups,
1110
+ dropout=dropout,
1111
+ time_embedding_norm=resnet_time_scale_shift,
1112
+ non_linearity=resnet_act_fn,
1113
+ output_scale_factor=output_scale_factor,
1114
+ pre_norm=resnet_pre_norm,
1115
+ )
1116
+ ]
1117
+ attentions = []
1118
+
1119
+ if attention_head_dim is None:
1120
+ attention_head_dim = in_channels
1121
+
1122
+ for _ in range(num_layers):
1123
+ if self.add_attention:
1124
+ attentions.append(
1125
+ Attention(
1126
+ in_channels,
1127
+ heads=in_channels // attention_head_dim,
1128
+ dim_head=attention_head_dim,
1129
+ rescale_output_factor=output_scale_factor,
1130
+ eps=resnet_eps,
1131
+ norm_num_groups=attn_groups,
1132
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
1133
+ residual_connection=True,
1134
+ bias=True,
1135
+ upcast_softmax=True,
1136
+ _from_deprecated_attn_block=True,
1137
+ )
1138
+ )
1139
+ else:
1140
+ attentions.append(None)
1141
+
1142
+ if resnet_time_scale_shift == "spatial":
1143
+ resnets.append(
1144
+ ResnetBlockCondNorm2D(
1145
+ in_channels=in_channels,
1146
+ out_channels=in_channels,
1147
+ temb_channels=temb_channels,
1148
+ eps=resnet_eps,
1149
+ groups=resnet_groups,
1150
+ dropout=dropout,
1151
+ time_embedding_norm="spatial",
1152
+ non_linearity=resnet_act_fn,
1153
+ output_scale_factor=output_scale_factor,
1154
+ )
1155
+ )
1156
+ else:
1157
+ resnets.append(
1158
+ ResnetBlock2D(
1159
+ in_channels=in_channels,
1160
+ out_channels=in_channels,
1161
+ temb_channels=temb_channels,
1162
+ eps=resnet_eps,
1163
+ groups=resnet_groups,
1164
+ dropout=dropout,
1165
+ time_embedding_norm=resnet_time_scale_shift,
1166
+ non_linearity=resnet_act_fn,
1167
+ output_scale_factor=output_scale_factor,
1168
+ pre_norm=resnet_pre_norm,
1169
+ )
1170
+ )
1171
+
1172
+ self.attentions = nn.ModuleList(attentions)
1173
+ self.resnets = nn.ModuleList(resnets)
1174
+
1175
+ self.gradient_checkpointing = False
1176
+
1177
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
1178
+ hidden_states = self.resnets[0](hidden_states, temb)
1179
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
1180
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1181
+ if attn is not None:
1182
+ hidden_states = attn(hidden_states, temb=temb)
1183
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
1184
+ else:
1185
+ if attn is not None:
1186
+ hidden_states = attn(hidden_states, temb=temb)
1187
+ hidden_states = resnet(hidden_states, temb)
1188
+
1189
+ return hidden_states
1190
+
1191
+
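# Usage sketch (illustrative only): the VAE-style mid block (resnet -> attention -> resnet).
# With attention_head_dim == in_channels the attention is single-head, which matches how
# the Encoder/Decoder below configure it; temb_channels=None means `temb` can be omitted.
mid = UNetMidBlock2D(in_channels=128, temb_channels=None, attention_head_dim=128)
h = mid(torch.randn(1, 128, 16, 16))                           # -> (1, 128, 16, 16)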
1192
+ class DownEncoderBlock2D(nn.Module):
1193
+ def __init__(
1194
+ self,
1195
+ in_channels: int,
1196
+ out_channels: int,
1197
+ dropout: float = 0.0,
1198
+ num_layers: int = 1,
1199
+ resnet_eps: float = 1e-6,
1200
+ resnet_time_scale_shift: str = "default",
1201
+ resnet_act_fn: str = "swish",
1202
+ resnet_groups: int = 32,
1203
+ resnet_pre_norm: bool = True,
1204
+ output_scale_factor: float = 1.0,
1205
+ add_downsample: bool = True,
1206
+ downsample_padding: int = 1,
1207
+ ):
1208
+ super().__init__()
1209
+ resnets = []
1210
+
1211
+ for i in range(num_layers):
1212
+ in_channels = in_channels if i == 0 else out_channels
1213
+ if resnet_time_scale_shift == "spatial":
1214
+ resnets.append(
1215
+ ResnetBlockCondNorm2D(
1216
+ in_channels=in_channels,
1217
+ out_channels=out_channels,
1218
+ temb_channels=None,
1219
+ eps=resnet_eps,
1220
+ groups=resnet_groups,
1221
+ dropout=dropout,
1222
+ time_embedding_norm="spatial",
1223
+ non_linearity=resnet_act_fn,
1224
+ output_scale_factor=output_scale_factor,
1225
+ )
1226
+ )
1227
+ else:
1228
+ resnets.append(
1229
+ ResnetBlock2D(
1230
+ in_channels=in_channels,
1231
+ out_channels=out_channels,
1232
+ temb_channels=None,
1233
+ eps=resnet_eps,
1234
+ groups=resnet_groups,
1235
+ dropout=dropout,
1236
+ time_embedding_norm=resnet_time_scale_shift,
1237
+ non_linearity=resnet_act_fn,
1238
+ output_scale_factor=output_scale_factor,
1239
+ pre_norm=resnet_pre_norm,
1240
+ )
1241
+ )
1242
+
1243
+ self.resnets = nn.ModuleList(resnets)
1244
+
1245
+ if add_downsample:
1246
+ self.downsamplers = nn.ModuleList(
1247
+ [
1248
+ Downsample2D(
1249
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
1250
+ )
1251
+ ]
1252
+ )
1253
+ else:
1254
+ self.downsamplers = None
1255
+
1256
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1257
+ for resnet in self.resnets:
1258
+ hidden_states = resnet(hidden_states, temb=None)
1259
+
1260
+ if self.downsamplers is not None:
1261
+ for downsampler in self.downsamplers:
1262
+ hidden_states = downsampler(hidden_states)
1263
+
1264
+ return hidden_states
1265
+
1266
+
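# Usage sketch (illustrative only): one encoder stage -- `num_layers` resnets (no time
# embedding) followed by an optional strided-conv downsample.
stage = DownEncoderBlock2D(in_channels=64, out_channels=128, num_layers=2, add_downsample=True)
h = stage(torch.randn(1, 64, 64, 64))                          # -> (1, 128, 32, 32)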
1267
+ class UpDecoderBlock2D(nn.Module):
1268
+ def __init__(
1269
+ self,
1270
+ in_channels: int,
1271
+ out_channels: int,
1272
+ resolution_idx: Optional[int] = None,
1273
+ dropout: float = 0.0,
1274
+ num_layers: int = 1,
1275
+ resnet_eps: float = 1e-6,
1276
+ resnet_time_scale_shift: str = "default", # default, spatial
1277
+ resnet_act_fn: str = "swish",
1278
+ resnet_groups: int = 32,
1279
+ resnet_pre_norm: bool = True,
1280
+ output_scale_factor: float = 1.0,
1281
+ add_upsample: bool = True,
1282
+ temb_channels: Optional[int] = None,
1283
+ ):
1284
+ super().__init__()
1285
+ resnets = []
1286
+
1287
+ for i in range(num_layers):
1288
+ input_channels = in_channels if i == 0 else out_channels
1289
+
1290
+ if resnet_time_scale_shift == "spatial":
1291
+ resnets.append(
1292
+ ResnetBlockCondNorm2D(
1293
+ in_channels=input_channels,
1294
+ out_channels=out_channels,
1295
+ temb_channels=temb_channels,
1296
+ eps=resnet_eps,
1297
+ groups=resnet_groups,
1298
+ dropout=dropout,
1299
+ time_embedding_norm="spatial",
1300
+ non_linearity=resnet_act_fn,
1301
+ output_scale_factor=output_scale_factor,
1302
+ )
1303
+ )
1304
+ else:
1305
+ resnets.append(
1306
+ ResnetBlock2D(
1307
+ in_channels=input_channels,
1308
+ out_channels=out_channels,
1309
+ temb_channels=temb_channels,
1310
+ eps=resnet_eps,
1311
+ groups=resnet_groups,
1312
+ dropout=dropout,
1313
+ time_embedding_norm=resnet_time_scale_shift,
1314
+ non_linearity=resnet_act_fn,
1315
+ output_scale_factor=output_scale_factor,
1316
+ pre_norm=resnet_pre_norm,
1317
+ )
1318
+ )
1319
+
1320
+ self.resnets = nn.ModuleList(resnets)
1321
+
1322
+ if add_upsample:
1323
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1324
+ else:
1325
+ self.upsamplers = None
1326
+
1327
+ self.resolution_idx = resolution_idx
1328
+
1329
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
1330
+ for resnet in self.resnets:
1331
+ hidden_states = resnet(hidden_states, temb=temb)
1332
+
1333
+ if self.upsamplers is not None:
1334
+ for upsampler in self.upsamplers:
1335
+ hidden_states = upsampler(hidden_states)
1336
+
1337
+ return hidden_states
1338
+
1339
+
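# Usage sketch (illustrative only): the decoder-side mirror -- resnets followed by a
# 2x nearest-neighbour upsample; temb_channels defaults to None, so no `temb` is needed.
stage = UpDecoderBlock2D(in_channels=128, out_channels=64, num_layers=3, add_upsample=True)
h = stage(torch.randn(1, 128, 32, 32))                         # -> (1, 64, 64, 64)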
1340
+ class Encoder(nn.Module):
1341
+ r"""
1342
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
1343
+
1344
+ Args:
1345
+ in_channels (`int`, *optional*, defaults to 3):
1346
+ The number of input channels.
1347
+ out_channels (`int`, *optional*, defaults to 3):
1348
+ The number of output channels.
1349
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1350
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
1351
+ options.
1352
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
1353
+ The number of output channels for each block.
1354
+ layers_per_block (`int`, *optional*, defaults to 2):
1355
+ The number of layers per block.
1356
+ norm_num_groups (`int`, *optional*, defaults to 32):
1357
+ The number of groups for normalization.
1358
+ act_fn (`str`, *optional*, defaults to `"silu"`):
1359
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
1360
+ double_z (`bool`, *optional*, defaults to `True`):
1361
+ Whether to double the number of output channels for the last block.
1362
+ """
1363
+
1364
+ def __init__(
1365
+ self,
1366
+ in_channels: int = 3,
1367
+ out_channels: int = 3,
1368
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
1369
+ block_out_channels: Tuple[int, ...] = (64,),
1370
+ layers_per_block: int = 2,
1371
+ norm_num_groups: int = 32,
1372
+ act_fn: str = "silu",
1373
+ double_z: bool = True,
1374
+ mid_block_add_attention=True,
1375
+ ):
1376
+ super().__init__()
1377
+ self.layers_per_block = layers_per_block
1378
+
1379
+ self.conv_in = nn.Conv2d(
1380
+ in_channels,
1381
+ block_out_channels[0],
1382
+ kernel_size=3,
1383
+ stride=1,
1384
+ padding=1,
1385
+ )
1386
+
1387
+ self.down_blocks = nn.ModuleList([])
1388
+
1389
+ # down
1390
+ output_channel = block_out_channels[0]
1391
+ for i, down_block_type in enumerate(down_block_types):
1392
+ input_channel = output_channel
1393
+ output_channel = block_out_channels[i]
1394
+ is_final_block = i == len(block_out_channels) - 1
1395
+
1396
+ down_block = DownEncoderBlock2D(
1397
+ num_layers=self.layers_per_block,
1398
+ in_channels=input_channel,
1399
+ out_channels=output_channel,
1400
+ add_downsample=not is_final_block,
1401
+ resnet_eps=1e-6,
1402
+ downsample_padding=0,
1403
+ resnet_act_fn=act_fn,
1404
+ resnet_groups=norm_num_groups,
1405
+ # attention_head_dim=output_channel,
1406
+ # temb_channels=None,
1407
+ )
1408
+ self.down_blocks.append(down_block)
1409
+
1410
+ # mid
1411
+ self.mid_block = UNetMidBlock2D(
1412
+ in_channels=block_out_channels[-1],
1413
+ resnet_eps=1e-6,
1414
+ resnet_act_fn=act_fn,
1415
+ output_scale_factor=1,
1416
+ resnet_time_scale_shift="default",
1417
+ attention_head_dim=block_out_channels[-1],
1418
+ resnet_groups=norm_num_groups,
1419
+ temb_channels=None,
1420
+ add_attention=mid_block_add_attention,
1421
+ )
1422
+
1423
+ # out
1424
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
1425
+ self.conv_act = nn.SiLU()
1426
+
1427
+ conv_out_channels = 2 * out_channels if double_z else out_channels
1428
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
1429
+
1430
+ self.gradient_checkpointing = False
1431
+
1432
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
1433
+ r"""The forward method of the `Encoder` class."""
1434
+
1435
+ sample = self.conv_in(sample)
1436
+
1437
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1438
+ # down
1439
+ for down_block in self.down_blocks:
1440
+ sample = self._gradient_checkpointing_func(down_block, sample)
1441
+ # middle
1442
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
1443
+
1444
+ else:
1445
+ # down
1446
+ for down_block in self.down_blocks:
1447
+ sample = down_block(sample)
1448
+
1449
+ # middle
1450
+ sample = self.mid_block(sample)
1451
+
1452
+ # post-process
1453
+ sample = self.conv_norm_out(sample)
1454
+ sample = self.conv_act(sample)
1455
+ sample = self.conv_out(sample)
1456
+
1457
+ return sample
1458
+
1459
+
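# Usage sketch (illustrative only; the channel layout below mirrors a typical SD-style
# VAE and is not taken from this package's configs). With double_z=True the output
# stacks mean and logvar along the channel axis, i.e. 2 * out_channels.
encoder = Encoder(
    in_channels=3,
    out_channels=4,
    down_block_types=("DownEncoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    double_z=True,
)
moments = encoder(torch.randn(1, 3, 256, 256))                 # -> (1, 8, 32, 32)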
1460
+ class Decoder(nn.Module):
1461
+ r"""
1462
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
1463
+
1464
+ Args:
1465
+ in_channels (`int`, *optional*, defaults to 3):
1466
+ The number of input channels.
1467
+ out_channels (`int`, *optional*, defaults to 3):
1468
+ The number of output channels.
1469
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1470
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
1471
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
1472
+ The number of output channels for each block.
1473
+ layers_per_block (`int`, *optional*, defaults to 2):
1474
+ The number of layers per block.
1475
+ norm_num_groups (`int`, *optional*, defaults to 32):
1476
+ The number of groups for normalization.
1477
+ act_fn (`str`, *optional*, defaults to `"silu"`):
1478
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
1479
+ norm_type (`str`, *optional*, defaults to `"group"`):
1480
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
1481
+ """
1482
+
1483
+ def __init__(
1484
+ self,
1485
+ in_channels: int = 3,
1486
+ out_channels: int = 3,
1487
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
1488
+ block_out_channels: Tuple[int, ...] = (64,),
1489
+ layers_per_block: int = 2,
1490
+ norm_num_groups: int = 32,
1491
+ act_fn: str = "silu",
1492
+ norm_type: str = "group", # group, spatial
1493
+ mid_block_add_attention: bool = True,
1494
+ ):
1495
+ super().__init__()
1496
+ self.layers_per_block = layers_per_block
1497
+
1498
+ self.conv_in = nn.Conv2d(
1499
+ in_channels,
1500
+ block_out_channels[-1],
1501
+ kernel_size=3,
1502
+ stride=1,
1503
+ padding=1,
1504
+ )
1505
+
1506
+ self.up_blocks = nn.ModuleList([])
1507
+
1508
+ temb_channels = in_channels if norm_type == "spatial" else None
1509
+
1510
+ # mid
1511
+ self.mid_block = UNetMidBlock2D(
1512
+ in_channels=block_out_channels[-1],
1513
+ resnet_eps=1e-6,
1514
+ resnet_act_fn=act_fn,
1515
+ output_scale_factor=1,
1516
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
1517
+ attention_head_dim=block_out_channels[-1],
1518
+ resnet_groups=norm_num_groups,
1519
+ temb_channels=temb_channels,
1520
+ add_attention=mid_block_add_attention,
1521
+ )
1522
+
1523
+ # up
1524
+ reversed_block_out_channels = list(reversed(block_out_channels))
1525
+ output_channel = reversed_block_out_channels[0]
1526
+ for i, up_block_type in enumerate(up_block_types):
1527
+ prev_output_channel = output_channel
1528
+ output_channel = reversed_block_out_channels[i]
1529
+
1530
+ is_final_block = i == len(block_out_channels) - 1
1531
+
1532
+ up_block = UpDecoderBlock2D(
1533
+ num_layers=self.layers_per_block + 1,
1534
+ in_channels=prev_output_channel,
1535
+ out_channels=output_channel,
1536
+ # prev_output_channel=prev_output_channel,
1537
+ add_upsample=not is_final_block,
1538
+ resnet_eps=1e-6,
1539
+ resnet_act_fn=act_fn,
1540
+ resnet_groups=norm_num_groups,
1541
+ # attention_head_dim=output_channel,
1542
+ temb_channels=temb_channels,
1543
+ resnet_time_scale_shift=norm_type,
1544
+ )
1545
+ self.up_blocks.append(up_block)
1546
+ prev_output_channel = output_channel
1547
+
1548
+ # out
1549
+ if norm_type == "spatial":
1550
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
1551
+ else:
1552
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
1553
+ self.conv_act = nn.SiLU()
1554
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
1555
+
1556
+ self.gradient_checkpointing = False
1557
+
1558
+ def forward(
1559
+ self,
1560
+ sample: torch.Tensor,
1561
+ latent_embeds: Optional[torch.Tensor] = None,
1562
+ ) -> torch.Tensor:
1563
+ r"""The forward method of the `Decoder` class."""
1564
+
1565
+ sample = self.conv_in(sample)
1566
+
1567
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1568
+ # middle
1569
+ sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
1570
+
1571
+ # up
1572
+ for up_block in self.up_blocks:
1573
+ sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
1574
+ else:
1575
+ # middle
1576
+ sample = self.mid_block(sample, latent_embeds)
1577
+
1578
+ # up
1579
+ for up_block in self.up_blocks:
1580
+ sample = up_block(sample, latent_embeds)
1581
+
1582
+ # post-process
1583
+ if latent_embeds is None:
1584
+ sample = self.conv_norm_out(sample)
1585
+ else:
1586
+ sample = self.conv_norm_out(sample, latent_embeds)
1587
+ sample = self.conv_act(sample)
1588
+ sample = self.conv_out(sample)
1589
+
1590
+ return sample
1591
+
1592
+
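A matching sketch for the Decoder, again illustrative rather than package code: one upsample per non-final up block, and latent_embeds is only consumed when norm_type="spatial", so it can be omitted with the default group norm:

import torch

dec = Decoder(
    in_channels=4,
    out_channels=3,
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=8,
)
z = torch.randn(1, 4, 32, 32)
with torch.no_grad():
    img = dec(z)                 # latent_embeds defaults to None (group norm)
print(img.shape)                 # torch.Size([1, 3, 64, 64])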
1593
+ class Flux2VAE(PreTrainedModel):
1594
+ r"""
1595
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
1596
+
1597
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for its generic methods implemented
1598
+ for all models (such as downloading or saving).
1599
+
1600
+ Parameters:
1601
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1602
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1603
+ down_block_types (`Tuple[str]`, *optional*, defaults to four `"DownEncoderBlock2D"` blocks):
1604
+ Tuple of downsample block types.
1605
+ up_block_types (`Tuple[str]`, *optional*, defaults to four `"UpDecoderBlock2D"` blocks):
1606
+ Tuple of upsample block types.
1607
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 512)`):
1608
+ Tuple of block output channels.
1609
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1610
+ latent_channels (`int`, *optional*, defaults to 32): Number of channels in the latent space.
1611
+ sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
1612
+ force_upcast (`bool`, *optional*, defaults to `True`):
1613
+ If enabled, it forces the VAE to run in float32 for high image resolution pipelines, such as SD-XL. The VAE
1614
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case `force_upcast`
1615
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1616
+ mid_block_add_attention (`bool`, *optional*, defaults to `True`):
1617
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to `False`, the
1618
+ mid_block will only have resnet blocks.
1619
+ """
1620
+
1621
+ _supports_gradient_checkpointing = True
1622
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
1623
+
1624
+ def __init__(
1625
+ self,
1626
+ in_channels: int = 3,
1627
+ out_channels: int = 3,
1628
+ down_block_types: Tuple[str, ...] = (
1629
+ "DownEncoderBlock2D",
1630
+ "DownEncoderBlock2D",
1631
+ "DownEncoderBlock2D",
1632
+ "DownEncoderBlock2D",
1633
+ ),
1634
+ up_block_types: Tuple[str, ...] = (
1635
+ "UpDecoderBlock2D",
1636
+ "UpDecoderBlock2D",
1637
+ "UpDecoderBlock2D",
1638
+ "UpDecoderBlock2D",
1639
+ ),
1640
+ block_out_channels: Tuple[int, ...] = (
1641
+ 128,
1642
+ 256,
1643
+ 512,
1644
+ 512,
1645
+ ),
1646
+ layers_per_block: int = 2,
1647
+ act_fn: str = "silu",
1648
+ latent_channels: int = 32,
1649
+ norm_num_groups: int = 32,
1650
+ sample_size: int = 1024, # YiYi notes: not sure
1651
+ force_upcast: bool = True,
1652
+ use_quant_conv: bool = True,
1653
+ use_post_quant_conv: bool = True,
1654
+ mid_block_add_attention: bool = True,
1655
+ batch_norm_eps: float = 1e-4,
1656
+ batch_norm_momentum: float = 0.1,
1657
+ patch_size: Tuple[int, int] = (2, 2),
1658
+ device: str = "cuda:0",
1659
+ dtype: torch.dtype = torch.float32,
1660
+ ):
1661
+ super().__init__()
1662
+
1663
+ # pass init params to Encoder
1664
+ self.encoder = Encoder(
1665
+ in_channels=in_channels,
1666
+ out_channels=latent_channels,
1667
+ down_block_types=down_block_types,
1668
+ block_out_channels=block_out_channels,
1669
+ layers_per_block=layers_per_block,
1670
+ act_fn=act_fn,
1671
+ norm_num_groups=norm_num_groups,
1672
+ double_z=True,
1673
+ mid_block_add_attention=mid_block_add_attention,
1674
+ )
1675
+
1676
+ # pass init params to Decoder
1677
+ self.decoder = Decoder(
1678
+ in_channels=latent_channels,
1679
+ out_channels=out_channels,
1680
+ up_block_types=up_block_types,
1681
+ block_out_channels=block_out_channels,
1682
+ layers_per_block=layers_per_block,
1683
+ norm_num_groups=norm_num_groups,
1684
+ act_fn=act_fn,
1685
+ mid_block_add_attention=mid_block_add_attention,
1686
+ )
1687
+
1688
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
1689
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
1690
+
1691
+ self.bn = nn.BatchNorm2d(
1692
+ math.prod(patch_size) * latent_channels,
1693
+ eps=batch_norm_eps,
1694
+ momentum=batch_norm_momentum,
1695
+ affine=False,
1696
+ track_running_stats=True,
1697
+ )
1698
+
1699
+ self.use_slicing = False
1700
+ self.use_tiling = False
1701
+
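A construction sketch for Flux2VAE with its default configuration (illustrative only): the running-stat BatchNorm is sized to math.prod(patch_size) * latent_channels = 2 * 2 * 32 = 128 channels, which matches the channel count of the patchified latents produced by encode below:

import torch

vae = Flux2VAE()                  # default 4-block config, latent_channels=32
print(vae.bn.num_features)        # 128 = 2 * 2 * 32
print(f"{sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M parameters")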
1702
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1703
+ batch_size, num_channels, height, width = x.shape
1704
+
1705
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
1706
+ return self._tiled_encode(x)
1707
+
1708
+ enc = self.encoder(x)
1709
+ if self.quant_conv is not None:
1710
+ enc = self.quant_conv(enc)
1711
+
1712
+ return enc
1713
+
1714
+ def encode(
1715
+ self, x: torch.Tensor, return_dict: bool = True
1716
+ ):
1717
+ """
1718
+ Encode a batch of images into latents.
1719
+
1720
+ Args:
1721
+ x (`torch.Tensor`): Input batch of images.
1722
+ return_dict (`bool`, *optional*, defaults to `True`):
1723
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1724
+
1725
+ Returns:
1726
+ `torch.Tensor`: The normalized, patchified latent representation of the encoded images
1728
+ (`return_dict` is currently ignored and a plain tensor is always returned).
1728
+ """
1729
+ if self.use_slicing and x.shape[0] > 1:
1730
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1731
+ h = torch.cat(encoded_slices)
1732
+ else:
1733
+ h = self._encode(x)
1734
+
1735
+ h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2)
1736
+ h = h[:, :128]  # keep only the mean half of the (mean, logvar) moments
1737
+ latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype)
1738
+ latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
1739
+ h.device, h.dtype
1740
+ )
1741
+ h = (h - latents_bn_mean) / latents_bn_std
1742
+ return h
1743
+
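A shape walk-through of encode under the default configuration, as a sketch rather than package code: three of the four down blocks halve the resolution, the 2x2 rearrange patchifies the moments, the channel slice keeps the mean half, and the BatchNorm running statistics normalize the result:

import torch

vae = Flux2VAE().eval()
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    latents = vae.encode(x)
# encoder: 128 -> 16 spatial, 2 * latent_channels = 64 moment channels
# rearrange(P=2, Q=2): (1, 64, 16, 16) -> (1, 256, 8, 8)
# [:, :128] keeps the mean half (channels are grouped as (C P Q))
print(latents.shape)   # torch.Size([1, 128, 8, 8])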
1744
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
1745
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
1746
+ return self.tiled_decode(z, return_dict=return_dict)
1747
+
1748
+ if self.post_quant_conv is not None:
1749
+ z = self.post_quant_conv(z)
1750
+
1751
+ dec = self.decoder(z)
1752
+
1753
+ if not return_dict:
1754
+ return (dec,)
1755
+
1756
+ return dec
1757
+
1758
+ def decode(
1759
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
1760
+ ):
1761
+ """
1762
+ Decode a batch of images.
1763
+
1764
+ Args:
1765
+ z (`torch.Tensor`): Input batch of latent vectors.
1766
+ return_dict (`bool`, *optional*, defaults to `True`):
1767
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1768
+
1769
+ Returns:
1770
+ [`~models.vae.DecoderOutput`] or `tuple`:
1771
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1772
+ returned.
1773
+ """
1774
+ latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype)
1775
+ latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
1776
+ z.device, z.dtype
1777
+ )
1778
+ z = z * latents_bn_std + latents_bn_mean
1779
+ z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2)
1781
+ if self.use_slicing and z.shape[0] > 1:
1782
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
1783
+ decoded = torch.cat(decoded_slices)
1784
+ else:
1785
+ decoded = self._decode(z)
1786
+
1787
+ if not return_dict:
1788
+ return (decoded,)
1789
+
1790
+ return decoded
1791
+
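The decode path mirrors encode: denormalize with the same running statistics, undo the 2x2 patchify, then upsample by 8x in the decoder. A self-contained round-trip sketch (untrained weights, so only the shapes are meaningful):

import torch

vae = Flux2VAE().eval()
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    latents = vae.encode(x)        # (1, 128, 8, 8)
    recon = vae.decode(latents)    # un-patchified to (1, 32, 16, 16), then decoded
print(recon.shape)                 # torch.Size([1, 3, 128, 128])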
1792
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1793
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
1794
+ for y in range(blend_extent):
1795
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
1796
+ return b
1797
+
1798
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1799
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1800
+ for x in range(blend_extent):
1801
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
1802
+ return b
1803
+
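blend_v and blend_h cross-fade the overlapping strip of two neighbouring tiles with linear weights. A tiny numeric illustration (self is unused by these helpers, so the method is called unbound here):

import torch

left = torch.ones(1, 1, 1, 8)      # tile `a`
right = torch.zeros(1, 1, 1, 8)    # tile `b`, blended in place
out = Flux2VAE.blend_h(None, left, right, blend_extent=4)
# column x of `b` gets weight x/4, the matching trailing column of `a` gets 1 - x/4
print(out[0, 0, 0])                # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, ...])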
1804
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1805
+ r"""Encode a batch of images using a tiled encoder.
1806
+
1807
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1808
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1809
+ different from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
1810
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1811
+ output, but they should be much less noticeable.
1812
+
1813
+ Args:
1814
+ x (`torch.Tensor`): Input batch of images.
1815
+
1816
+ Returns:
1817
+ `torch.Tensor`:
1818
+ The latent representation of the encoded images.
1819
+ """
1820
+
1821
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
1822
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
1823
+ row_limit = self.tile_latent_min_size - blend_extent
1824
+
1825
+ # Split the image into tile_sample_min_size x tile_sample_min_size tiles and encode them separately.
1826
+ rows = []
1827
+ for i in range(0, x.shape[2], overlap_size):
1828
+ row = []
1829
+ for j in range(0, x.shape[3], overlap_size):
1830
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
1831
+ tile = self.encoder(tile)
1832
+ if self.quant_conv is not None:
1833
+ tile = self.quant_conv(tile)
1834
+ row.append(tile)
1835
+ rows.append(row)
1836
+ result_rows = []
1837
+ for i, row in enumerate(rows):
1838
+ result_row = []
1839
+ for j, tile in enumerate(row):
1840
+ # blend the above tile and the left tile
1841
+ # to the current tile and add the current tile to the result row
1842
+ if i > 0:
1843
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
1844
+ if j > 0:
1845
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1846
+ result_row.append(tile[:, :, :row_limit, :row_limit])
1847
+ result_rows.append(torch.cat(result_row, dim=3))
1848
+
1849
+ enc = torch.cat(result_rows, dim=2)
1850
+ return enc
1851
+
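The tile bookkeeping above relies on tile_sample_min_size, tile_latent_min_size and tile_overlap_factor, which are not assigned in the __init__ shown in this diff; the values below are illustrative assumptions (typical AutoencoderKL-style settings) chosen only to make the arithmetic concrete:

# assumed values, not taken from this file
tile_sample_min_size = 512        # pixel-space tile edge
tile_latent_min_size = 512 // 8   # 64, after three 2x downsamples
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))   # 384: stride between tiles
blend_extent = int(tile_latent_min_size * tile_overlap_factor)         # 16: blended rows/columns in latent space
row_limit = tile_latent_min_size - blend_extent                        # 48: kept per tile after cropping
print(overlap_size, blend_extent, row_limit)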
1852
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True):
1853
+ r"""Encode a batch of images using a tiled encoder.
1854
+
1855
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1856
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1857
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1858
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1859
+ output, but they should be much less noticeable.
1860
+
1861
+ Args:
1862
+ x (`torch.Tensor`): Input batch of images.
1863
+ return_dict (`bool`, *optional*, defaults to `True`):
1864
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1865
+
1866
+ Returns:
1867
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
1868
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
1869
+ `tuple` is returned.
1870
+ """
1871
+
1872
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
1873
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
1874
+ row_limit = self.tile_latent_min_size - blend_extent
1875
+
1876
+ # Split the image into tile_sample_min_size x tile_sample_min_size tiles and encode them separately.
1877
+ rows = []
1878
+ for i in range(0, x.shape[2], overlap_size):
1879
+ row = []
1880
+ for j in range(0, x.shape[3], overlap_size):
1881
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
1882
+ tile = self.encoder(tile)
1883
+ if self.quant_conv is not None:
1884
+ tile = self.quant_conv(tile)
1885
+ row.append(tile)
1886
+ rows.append(row)
1887
+ result_rows = []
1888
+ for i, row in enumerate(rows):
1889
+ result_row = []
1890
+ for j, tile in enumerate(row):
1891
+ # blend the above tile and the left tile
1892
+ # to the current tile and add the current tile to the result row
1893
+ if i > 0:
1894
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
1895
+ if j > 0:
1896
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1897
+ result_row.append(tile[:, :, :row_limit, :row_limit])
1898
+ result_rows.append(torch.cat(result_row, dim=3))
1899
+
1900
+ moments = torch.cat(result_rows, dim=2)
1901
+ return moments
1902
+
1903
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
1904
+ r"""
1905
+ Decode a batch of images using a tiled decoder.
1906
+
1907
+ Args:
1908
+ z (`torch.Tensor`): Input batch of latent vectors.
1909
+ return_dict (`bool`, *optional*, defaults to `True`):
1910
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1911
+
1912
+ Returns:
1913
+ [`~models.vae.DecoderOutput`] or `tuple`:
1914
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1915
+ returned.
1916
+ """
1917
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
1918
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
1919
+ row_limit = self.tile_sample_min_size - blend_extent
1920
+
1921
+ # Split z into overlapping tile_latent_min_size x tile_latent_min_size tiles and decode them separately.
1922
+ # The tiles have an overlap to avoid seams between tiles.
1923
+ rows = []
1924
+ for i in range(0, z.shape[2], overlap_size):
1925
+ row = []
1926
+ for j in range(0, z.shape[3], overlap_size):
1927
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
1928
+ if self.post_quant_conv is not None:
1929
+ tile = self.post_quant_conv(tile)
1930
+ decoded = self.decoder(tile)
1931
+ row.append(decoded)
1932
+ rows.append(row)
1933
+ result_rows = []
1934
+ for i, row in enumerate(rows):
1935
+ result_row = []
1936
+ for j, tile in enumerate(row):
1937
+ # blend the above tile and the left tile
1938
+ # to the current tile and add the current tile to the result row
1939
+ if i > 0:
1940
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
1941
+ if j > 0:
1942
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
1943
+ result_row.append(tile[:, :, :row_limit, :row_limit])
1944
+ result_rows.append(torch.cat(result_row, dim=3))
1945
+
1946
+ dec = torch.cat(result_rows, dim=2)
1947
+ if not return_dict:
1948
+ return (dec,)
1949
+
1950
+ return dec
1951
+
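Besides tiling, the class exposes a use_slicing flag: when enabled, encode and decode process the batch one sample at a time to lower peak memory. A short sketch:

import torch

vae = Flux2VAE().eval()
vae.use_slicing = True
batch = torch.randn(2, 3, 128, 128)
with torch.no_grad():
    lat = vae.encode(batch)        # two single-image passes, concatenated
print(lat.shape)                   # torch.Size([2, 128, 8, 8])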
1952
+ def forward(
1953
+ self,
1954
+ sample: torch.Tensor,
1955
+ sample_posterior: bool = False,
1956
+ return_dict: bool = True,
1957
+ generator: Optional[torch.Generator] = None,
1958
+ ):
1959
+ r"""
1960
+ Args:
1961
+ sample (`torch.Tensor`): Input sample.
1962
+ sample_posterior (`bool`, *optional*, defaults to `False`):
1963
+ Whether to sample from the posterior.
1964
+ return_dict (`bool`, *optional*, defaults to `True`):
1965
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1966
+ """
1967
+ x = sample
1968
+ # `encode` returns normalized, patchified latents as a plain tensor,
1969
+ # so there is no latent distribution to sample from in this class;
1970
+ # `sample_posterior` and `generator` are therefore accepted only for
1971
+ # interface compatibility and are not used.
1972
+ z = self.encode(x)
1973
+ dec = self.decode(z)
1974
+
1975
+ if not return_dict:
1976
+ return (dec,)
1977
+
1978
+ return dec
1979
+
1980
+ @classmethod
1981
+ def from_state_dict(
1982
+ cls,
1983
+ state_dict: Dict[str, torch.Tensor],
1984
+ device: str = "cuda:0",
1985
+ dtype: torch.dtype = torch.float32,
1986
+ **kwargs,
1987
+ ) -> "Flux2VAE":
1988
+ model = cls(device="meta", dtype=dtype, **kwargs)
1989
+ model = model.requires_grad_(False)
1990
+ model.load_state_dict(state_dict, assign=True)
1991
+ model.to(device=device, dtype=dtype, non_blocking=True)
1992
+ return model
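from_state_dict builds the module, attaches the given weights with assign=True, and moves the result to the requested device and dtype. A self-contained sketch that round-trips the weights of a freshly constructed model; a real checkpoint would normally be loaded from disk instead:

import torch

src = Flux2VAE()
vae = Flux2VAE.from_state_dict(src.state_dict(), device="cpu", dtype=torch.float32)
print(next(vae.parameters()).dtype)   # torch.float32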