diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +6 -0
- diffsynth_engine/conf/models/flux2/qwen3_8B_config.json +68 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +50 -1
- diffsynth_engine/models/flux2/__init__.py +7 -0
- diffsynth_engine/models/flux2/flux2_dit.py +1065 -0
- diffsynth_engine/models/flux2/flux2_vae.py +1992 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/flux2_klein_image.py +634 -0
- diffsynth_engine/utils/constants.py +1 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/RECORD +15 -10
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/WHEEL +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/top_level.txt +0 -0
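
The largest new file is the VAE implementation in diffsynth_engine/models/flux2/flux2_vae.py, whose hunk follows. As a rough smoke-test sketch (not part of the package, and assuming only the constructor and forward signatures visible in the hunk below), one of its building blocks could be exercised like this:

    import torch
    from diffsynth_engine.models.flux2.flux2_vae import ResnetBlock2D

    # Keyword-only constructor, per the signature shown in the hunk below.
    block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
    x = torch.randn(1, 64, 32, 32)   # (batch, channels, height, width)
    temb = torch.randn(1, 512)       # timestep embedding consumed by time_emb_proj
    out = block(x, temb)             # the 1x1 conv_shortcut handles the 64 -> 128 channel change
    print(out.shape)                 # torch.Size([1, 128, 32, 32])
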
|
@@ -0,0 +1,1992 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Dict, Optional, Tuple, Union, Callable
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from diffsynth_engine.models.base import PreTrainedModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
ACT2CLS = {
|
|
13
|
+
"swish": nn.SiLU,
|
|
14
|
+
"silu": nn.SiLU,
|
|
15
|
+
"mish": nn.Mish,
|
|
16
|
+
"gelu": nn.GELU,
|
|
17
|
+
"relu": nn.ReLU,
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_activation(act_fn: str) -> nn.Module:
|
|
22
|
+
"""Helper function to get activation function from string.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
act_fn (str): Name of activation function.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
nn.Module: Activation function.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
act_fn = act_fn.lower()
|
|
32
|
+
if act_fn in ACT2CLS:
|
|
33
|
+
return ACT2CLS[act_fn]()
|
|
34
|
+
else:
|
|
35
|
+
raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ResnetBlock2D(nn.Module):
|
|
39
|
+
r"""
|
|
40
|
+
A Resnet block.
|
|
41
|
+
|
|
42
|
+
Parameters:
|
|
43
|
+
in_channels (`int`): The number of channels in the input.
|
|
44
|
+
out_channels (`int`, *optional*, default to be `None`):
|
|
45
|
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
|
46
|
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
|
47
|
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
|
48
|
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
|
49
|
+
groups_out (`int`, *optional*, default to None):
|
|
50
|
+
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
|
51
|
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
|
52
|
+
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
|
53
|
+
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
|
54
|
+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
|
|
55
|
+
stronger conditioning with scale and shift.
|
|
56
|
+
kernel (`torch.Tensor`, optional, default to None): FIR filter, see
|
|
57
|
+
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
|
58
|
+
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
|
59
|
+
use_in_shortcut (`bool`, *optional*, default to `True`):
|
|
60
|
+
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
|
61
|
+
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
|
62
|
+
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
|
63
|
+
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
|
64
|
+
`conv_shortcut` output.
|
|
65
|
+
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
|
66
|
+
If None, same as `out_channels`.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def __init__(
|
|
70
|
+
self,
|
|
71
|
+
*,
|
|
72
|
+
in_channels: int,
|
|
73
|
+
out_channels: Optional[int] = None,
|
|
74
|
+
conv_shortcut: bool = False,
|
|
75
|
+
dropout: float = 0.0,
|
|
76
|
+
temb_channels: int = 512,
|
|
77
|
+
groups: int = 32,
|
|
78
|
+
groups_out: Optional[int] = None,
|
|
79
|
+
pre_norm: bool = True,
|
|
80
|
+
eps: float = 1e-6,
|
|
81
|
+
non_linearity: str = "swish",
|
|
82
|
+
skip_time_act: bool = False,
|
|
83
|
+
time_embedding_norm: str = "default", # default, scale_shift,
|
|
84
|
+
kernel: Optional[torch.Tensor] = None,
|
|
85
|
+
output_scale_factor: float = 1.0,
|
|
86
|
+
use_in_shortcut: Optional[bool] = None,
|
|
87
|
+
up: bool = False,
|
|
88
|
+
down: bool = False,
|
|
89
|
+
conv_shortcut_bias: bool = True,
|
|
90
|
+
conv_2d_out_channels: Optional[int] = None,
|
|
91
|
+
):
|
|
92
|
+
super().__init__()
|
|
93
|
+
|
|
94
|
+
self.pre_norm = True
|
|
95
|
+
self.in_channels = in_channels
|
|
96
|
+
out_channels = in_channels if out_channels is None else out_channels
|
|
97
|
+
self.out_channels = out_channels
|
|
98
|
+
self.use_conv_shortcut = conv_shortcut
|
|
99
|
+
self.up = up
|
|
100
|
+
self.down = down
|
|
101
|
+
self.output_scale_factor = output_scale_factor
|
|
102
|
+
self.time_embedding_norm = time_embedding_norm
|
|
103
|
+
self.skip_time_act = skip_time_act
|
|
104
|
+
|
|
105
|
+
if groups_out is None:
|
|
106
|
+
groups_out = groups
|
|
107
|
+
|
|
108
|
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
|
109
|
+
|
|
110
|
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
111
|
+
|
|
112
|
+
if temb_channels is not None:
|
|
113
|
+
if self.time_embedding_norm == "default":
|
|
114
|
+
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
|
115
|
+
elif self.time_embedding_norm == "scale_shift":
|
|
116
|
+
self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
|
|
117
|
+
else:
|
|
118
|
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
|
119
|
+
else:
|
|
120
|
+
self.time_emb_proj = None
|
|
121
|
+
|
|
122
|
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
|
123
|
+
|
|
124
|
+
self.dropout = torch.nn.Dropout(dropout)
|
|
125
|
+
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
|
126
|
+
self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
|
127
|
+
|
|
128
|
+
self.nonlinearity = get_activation(non_linearity)
|
|
129
|
+
|
|
130
|
+
self.upsample = self.downsample = None
|
|
131
|
+
if self.up:
|
|
132
|
+
if kernel == "fir":
|
|
133
|
+
fir_kernel = (1, 3, 3, 1)
|
|
134
|
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
|
135
|
+
elif kernel == "sde_vp":
|
|
136
|
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
|
137
|
+
else:
|
|
138
|
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
|
139
|
+
elif self.down:
|
|
140
|
+
if kernel == "fir":
|
|
141
|
+
fir_kernel = (1, 3, 3, 1)
|
|
142
|
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
|
143
|
+
elif kernel == "sde_vp":
|
|
144
|
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
|
145
|
+
else:
|
|
146
|
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
|
147
|
+
|
|
148
|
+
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
|
|
149
|
+
|
|
150
|
+
self.conv_shortcut = None
|
|
151
|
+
if self.use_in_shortcut:
|
|
152
|
+
self.conv_shortcut = nn.Conv2d(
|
|
153
|
+
in_channels,
|
|
154
|
+
conv_2d_out_channels,
|
|
155
|
+
kernel_size=1,
|
|
156
|
+
stride=1,
|
|
157
|
+
padding=0,
|
|
158
|
+
bias=conv_shortcut_bias,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
|
162
|
+
hidden_states = input_tensor
|
|
163
|
+
|
|
164
|
+
hidden_states = self.norm1(hidden_states)
|
|
165
|
+
hidden_states = self.nonlinearity(hidden_states)
|
|
166
|
+
|
|
167
|
+
if self.upsample is not None:
|
|
168
|
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
|
169
|
+
if hidden_states.shape[0] >= 64:
|
|
170
|
+
input_tensor = input_tensor.contiguous()
|
|
171
|
+
hidden_states = hidden_states.contiguous()
|
|
172
|
+
input_tensor = self.upsample(input_tensor)
|
|
173
|
+
hidden_states = self.upsample(hidden_states)
|
|
174
|
+
elif self.downsample is not None:
|
|
175
|
+
input_tensor = self.downsample(input_tensor)
|
|
176
|
+
hidden_states = self.downsample(hidden_states)
|
|
177
|
+
|
|
178
|
+
hidden_states = self.conv1(hidden_states)
|
|
179
|
+
|
|
180
|
+
if self.time_emb_proj is not None:
|
|
181
|
+
if not self.skip_time_act:
|
|
182
|
+
temb = self.nonlinearity(temb)
|
|
183
|
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
|
184
|
+
|
|
185
|
+
if self.time_embedding_norm == "default":
|
|
186
|
+
if temb is not None:
|
|
187
|
+
hidden_states = hidden_states + temb
|
|
188
|
+
hidden_states = self.norm2(hidden_states)
|
|
189
|
+
elif self.time_embedding_norm == "scale_shift":
|
|
190
|
+
if temb is None:
|
|
191
|
+
raise ValueError(
|
|
192
|
+
f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
|
|
193
|
+
)
|
|
194
|
+
time_scale, time_shift = torch.chunk(temb, 2, dim=1)
|
|
195
|
+
hidden_states = self.norm2(hidden_states)
|
|
196
|
+
hidden_states = hidden_states * (1 + time_scale) + time_shift
|
|
197
|
+
else:
|
|
198
|
+
hidden_states = self.norm2(hidden_states)
|
|
199
|
+
|
|
200
|
+
hidden_states = self.nonlinearity(hidden_states)
|
|
201
|
+
|
|
202
|
+
hidden_states = self.dropout(hidden_states)
|
|
203
|
+
hidden_states = self.conv2(hidden_states)
|
|
204
|
+
|
|
205
|
+
if self.conv_shortcut is not None:
|
|
206
|
+
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
|
207
|
+
|
|
208
|
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
|
209
|
+
|
|
210
|
+
return output_tensor
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class Downsample2D(nn.Module):
|
|
214
|
+
"""A 2D downsampling layer with an optional convolution.
|
|
215
|
+
|
|
216
|
+
Parameters:
|
|
217
|
+
channels (`int`):
|
|
218
|
+
number of channels in the inputs and outputs.
|
|
219
|
+
use_conv (`bool`, default `False`):
|
|
220
|
+
option to use a convolution.
|
|
221
|
+
out_channels (`int`, optional):
|
|
222
|
+
number of output channels. Defaults to `channels`.
|
|
223
|
+
padding (`int`, default `1`):
|
|
224
|
+
padding for the convolution.
|
|
225
|
+
name (`str`, default `conv`):
|
|
226
|
+
name of the downsampling 2D layer.
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
def __init__(
|
|
230
|
+
self,
|
|
231
|
+
channels: int,
|
|
232
|
+
use_conv: bool = False,
|
|
233
|
+
out_channels: Optional[int] = None,
|
|
234
|
+
padding: int = 1,
|
|
235
|
+
name: str = "conv",
|
|
236
|
+
kernel_size=3,
|
|
237
|
+
norm_type=None,
|
|
238
|
+
eps=None,
|
|
239
|
+
elementwise_affine=None,
|
|
240
|
+
bias=True,
|
|
241
|
+
):
|
|
242
|
+
super().__init__()
|
|
243
|
+
self.channels = channels
|
|
244
|
+
self.out_channels = out_channels or channels
|
|
245
|
+
self.use_conv = use_conv
|
|
246
|
+
self.padding = padding
|
|
247
|
+
stride = 2
|
|
248
|
+
self.name = name
|
|
249
|
+
|
|
250
|
+
if norm_type == "ln_norm":
|
|
251
|
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
|
252
|
+
elif norm_type == "rms_norm":
|
|
253
|
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
|
254
|
+
elif norm_type is None:
|
|
255
|
+
self.norm = None
|
|
256
|
+
else:
|
|
257
|
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
|
258
|
+
|
|
259
|
+
if use_conv:
|
|
260
|
+
conv = nn.Conv2d(
|
|
261
|
+
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
|
|
262
|
+
)
|
|
263
|
+
else:
|
|
264
|
+
assert self.channels == self.out_channels
|
|
265
|
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
|
266
|
+
|
|
267
|
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
|
268
|
+
if name == "conv":
|
|
269
|
+
self.Conv2d_0 = conv
|
|
270
|
+
self.conv = conv
|
|
271
|
+
elif name == "Conv2d_0":
|
|
272
|
+
self.conv = conv
|
|
273
|
+
else:
|
|
274
|
+
self.conv = conv
|
|
275
|
+
|
|
276
|
+
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
|
277
|
+
assert hidden_states.shape[1] == self.channels
|
|
278
|
+
|
|
279
|
+
if self.norm is not None:
|
|
280
|
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
281
|
+
|
|
282
|
+
if self.use_conv and self.padding == 0:
|
|
283
|
+
pad = (0, 1, 0, 1)
|
|
284
|
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
|
285
|
+
|
|
286
|
+
assert hidden_states.shape[1] == self.channels
|
|
287
|
+
|
|
288
|
+
hidden_states = self.conv(hidden_states)
|
|
289
|
+
|
|
290
|
+
return hidden_states
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class Upsample2D(nn.Module):
|
|
294
|
+
"""A 2D upsampling layer with an optional convolution.
|
|
295
|
+
|
|
296
|
+
Parameters:
|
|
297
|
+
channels (`int`):
|
|
298
|
+
number of channels in the inputs and outputs.
|
|
299
|
+
use_conv (`bool`, default `False`):
|
|
300
|
+
option to use a convolution.
|
|
301
|
+
use_conv_transpose (`bool`, default `False`):
|
|
302
|
+
option to use a convolution transpose.
|
|
303
|
+
out_channels (`int`, optional):
|
|
304
|
+
number of output channels. Defaults to `channels`.
|
|
305
|
+
name (`str`, default `conv`):
|
|
306
|
+
name of the upsampling 2D layer.
|
|
307
|
+
"""
|
|
308
|
+
|
|
309
|
+
def __init__(
|
|
310
|
+
self,
|
|
311
|
+
channels: int,
|
|
312
|
+
use_conv: bool = False,
|
|
313
|
+
use_conv_transpose: bool = False,
|
|
314
|
+
out_channels: Optional[int] = None,
|
|
315
|
+
name: str = "conv",
|
|
316
|
+
kernel_size: Optional[int] = None,
|
|
317
|
+
padding=1,
|
|
318
|
+
norm_type=None,
|
|
319
|
+
eps=None,
|
|
320
|
+
elementwise_affine=None,
|
|
321
|
+
bias=True,
|
|
322
|
+
interpolate=True,
|
|
323
|
+
):
|
|
324
|
+
super().__init__()
|
|
325
|
+
self.channels = channels
|
|
326
|
+
self.out_channels = out_channels or channels
|
|
327
|
+
self.use_conv = use_conv
|
|
328
|
+
self.use_conv_transpose = use_conv_transpose
|
|
329
|
+
self.name = name
|
|
330
|
+
self.interpolate = interpolate
|
|
331
|
+
|
|
332
|
+
if norm_type == "ln_norm":
|
|
333
|
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
|
334
|
+
elif norm_type == "rms_norm":
|
|
335
|
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
|
336
|
+
elif norm_type is None:
|
|
337
|
+
self.norm = None
|
|
338
|
+
else:
|
|
339
|
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
|
340
|
+
|
|
341
|
+
conv = None
|
|
342
|
+
if use_conv_transpose:
|
|
343
|
+
if kernel_size is None:
|
|
344
|
+
kernel_size = 4
|
|
345
|
+
conv = nn.ConvTranspose2d(
|
|
346
|
+
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
|
|
347
|
+
)
|
|
348
|
+
elif use_conv:
|
|
349
|
+
if kernel_size is None:
|
|
350
|
+
kernel_size = 3
|
|
351
|
+
conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
|
|
352
|
+
|
|
353
|
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
|
354
|
+
if name == "conv":
|
|
355
|
+
self.conv = conv
|
|
356
|
+
else:
|
|
357
|
+
self.Conv2d_0 = conv
|
|
358
|
+
|
|
359
|
+
def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
|
|
360
|
+
assert hidden_states.shape[1] == self.channels
|
|
361
|
+
|
|
362
|
+
if self.norm is not None:
|
|
363
|
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
364
|
+
|
|
365
|
+
if self.use_conv_transpose:
|
|
366
|
+
return self.conv(hidden_states)
|
|
367
|
+
|
|
368
|
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
|
|
369
|
+
# https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
|
|
370
|
+
dtype = hidden_states.dtype
|
|
371
|
+
if dtype == torch.bfloat16:
|
|
372
|
+
hidden_states = hidden_states.to(torch.float32)
|
|
373
|
+
|
|
374
|
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
|
375
|
+
if hidden_states.shape[0] >= 64:
|
|
376
|
+
hidden_states = hidden_states.contiguous()
|
|
377
|
+
|
|
378
|
+
# if `output_size` is passed we force the interpolation output
|
|
379
|
+
# size and do not make use of `scale_factor=2`
|
|
380
|
+
if self.interpolate:
|
|
381
|
+
# upsample_nearest_nhwc also fails when the number of output elements is large
|
|
382
|
+
# https://github.com/pytorch/pytorch/issues/141831
|
|
383
|
+
scale_factor = (
|
|
384
|
+
2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])
|
|
385
|
+
)
|
|
386
|
+
if hidden_states.numel() * scale_factor > pow(2, 31):
|
|
387
|
+
hidden_states = hidden_states.contiguous()
|
|
388
|
+
|
|
389
|
+
if output_size is None:
|
|
390
|
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
|
391
|
+
else:
|
|
392
|
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
|
393
|
+
|
|
394
|
+
# Cast back to original dtype
|
|
395
|
+
if dtype == torch.bfloat16:
|
|
396
|
+
hidden_states = hidden_states.to(dtype)
|
|
397
|
+
|
|
398
|
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
|
399
|
+
if self.use_conv:
|
|
400
|
+
if self.name == "conv":
|
|
401
|
+
hidden_states = self.conv(hidden_states)
|
|
402
|
+
else:
|
|
403
|
+
hidden_states = self.Conv2d_0(hidden_states)
|
|
404
|
+
|
|
405
|
+
return hidden_states
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class Attention(nn.Module):
|
|
409
|
+
r"""
|
|
410
|
+
A cross attention layer.
|
|
411
|
+
|
|
412
|
+
Parameters:
|
|
413
|
+
query_dim (`int`):
|
|
414
|
+
The number of channels in the query.
|
|
415
|
+
cross_attention_dim (`int`, *optional*):
|
|
416
|
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
|
417
|
+
heads (`int`, *optional*, defaults to 8):
|
|
418
|
+
The number of heads to use for multi-head attention.
|
|
419
|
+
kv_heads (`int`, *optional*, defaults to `None`):
|
|
420
|
+
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
|
|
421
|
+
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
|
|
422
|
+
Query Attention (MQA) otherwise GQA is used.
|
|
423
|
+
dim_head (`int`, *optional*, defaults to 64):
|
|
424
|
+
The number of channels in each head.
|
|
425
|
+
dropout (`float`, *optional*, defaults to 0.0):
|
|
426
|
+
The dropout probability to use.
|
|
427
|
+
bias (`bool`, *optional*, defaults to False):
|
|
428
|
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
|
429
|
+
upcast_attention (`bool`, *optional*, defaults to False):
|
|
430
|
+
Set to `True` to upcast the attention computation to `float32`.
|
|
431
|
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
|
432
|
+
Set to `True` to upcast the softmax computation to `float32`.
|
|
433
|
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
|
434
|
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
|
435
|
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
|
436
|
+
The number of groups to use for the group norm in the cross attention.
|
|
437
|
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
|
438
|
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
|
439
|
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
|
440
|
+
The number of groups to use for the group norm in the attention.
|
|
441
|
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
|
442
|
+
The number of channels to use for the spatial normalization.
|
|
443
|
+
out_bias (`bool`, *optional*, defaults to `True`):
|
|
444
|
+
Set to `True` to use a bias in the output linear layer.
|
|
445
|
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
|
446
|
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
|
447
|
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
|
448
|
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
|
449
|
+
`added_kv_proj_dim` is not `None`.
|
|
450
|
+
eps (`float`, *optional*, defaults to 1e-5):
|
|
451
|
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
|
452
|
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
|
453
|
+
A factor to rescale the output by dividing it with this value.
|
|
454
|
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
|
455
|
+
Set to `True` to add the residual connection to the output.
|
|
456
|
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
|
457
|
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
|
458
|
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
|
459
|
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
|
460
|
+
`AttnProcessor` otherwise.
|
|
461
|
+
"""
|
|
462
|
+
|
|
463
|
+
def __init__(
|
|
464
|
+
self,
|
|
465
|
+
query_dim: int,
|
|
466
|
+
cross_attention_dim: Optional[int] = None,
|
|
467
|
+
heads: int = 8,
|
|
468
|
+
kv_heads: Optional[int] = None,
|
|
469
|
+
dim_head: int = 64,
|
|
470
|
+
dropout: float = 0.0,
|
|
471
|
+
bias: bool = False,
|
|
472
|
+
upcast_attention: bool = False,
|
|
473
|
+
upcast_softmax: bool = False,
|
|
474
|
+
cross_attention_norm: Optional[str] = None,
|
|
475
|
+
cross_attention_norm_num_groups: int = 32,
|
|
476
|
+
qk_norm: Optional[str] = None,
|
|
477
|
+
added_kv_proj_dim: Optional[int] = None,
|
|
478
|
+
added_proj_bias: Optional[bool] = True,
|
|
479
|
+
norm_num_groups: Optional[int] = None,
|
|
480
|
+
spatial_norm_dim: Optional[int] = None,
|
|
481
|
+
out_bias: bool = True,
|
|
482
|
+
scale_qk: bool = True,
|
|
483
|
+
only_cross_attention: bool = False,
|
|
484
|
+
eps: float = 1e-5,
|
|
485
|
+
rescale_output_factor: float = 1.0,
|
|
486
|
+
residual_connection: bool = False,
|
|
487
|
+
_from_deprecated_attn_block: bool = False,
|
|
488
|
+
processor: Optional["AttnProcessor"] = None,
|
|
489
|
+
out_dim: int = None,
|
|
490
|
+
out_context_dim: int = None,
|
|
491
|
+
context_pre_only=None,
|
|
492
|
+
pre_only=False,
|
|
493
|
+
elementwise_affine: bool = True,
|
|
494
|
+
is_causal: bool = False,
|
|
495
|
+
):
|
|
496
|
+
super().__init__()
|
|
497
|
+
|
|
498
|
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
|
499
|
+
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
|
500
|
+
self.query_dim = query_dim
|
|
501
|
+
self.use_bias = bias
|
|
502
|
+
self.is_cross_attention = cross_attention_dim is not None
|
|
503
|
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
|
504
|
+
self.upcast_attention = upcast_attention
|
|
505
|
+
self.upcast_softmax = upcast_softmax
|
|
506
|
+
self.rescale_output_factor = rescale_output_factor
|
|
507
|
+
self.residual_connection = residual_connection
|
|
508
|
+
self.dropout = dropout
|
|
509
|
+
self.fused_projections = False
|
|
510
|
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
|
511
|
+
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
|
512
|
+
self.context_pre_only = context_pre_only
|
|
513
|
+
self.pre_only = pre_only
|
|
514
|
+
self.is_causal = is_causal
|
|
515
|
+
|
|
516
|
+
# we make use of this private variable to know whether this class is loaded
|
|
517
|
+
# with an deprecated state dict so that we can convert it on the fly
|
|
518
|
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
|
519
|
+
|
|
520
|
+
self.scale_qk = scale_qk
|
|
521
|
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
|
522
|
+
|
|
523
|
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
|
524
|
+
# for slice_size > 0 the attention score computation
|
|
525
|
+
# is split across the batch axis to save memory
|
|
526
|
+
# You can set slice_size with `set_attention_slice`
|
|
527
|
+
self.sliceable_head_dim = heads
|
|
528
|
+
|
|
529
|
+
self.added_kv_proj_dim = added_kv_proj_dim
|
|
530
|
+
self.only_cross_attention = only_cross_attention
|
|
531
|
+
|
|
532
|
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
|
533
|
+
raise ValueError(
|
|
534
|
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
|
535
|
+
)
|
|
536
|
+
|
|
537
|
+
if norm_num_groups is not None:
|
|
538
|
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
|
539
|
+
else:
|
|
540
|
+
self.group_norm = None
|
|
541
|
+
|
|
542
|
+
if spatial_norm_dim is not None:
|
|
543
|
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
|
544
|
+
else:
|
|
545
|
+
self.spatial_norm = None
|
|
546
|
+
|
|
547
|
+
if qk_norm is None:
|
|
548
|
+
self.norm_q = None
|
|
549
|
+
self.norm_k = None
|
|
550
|
+
elif qk_norm == "layer_norm":
|
|
551
|
+
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
552
|
+
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
553
|
+
elif qk_norm == "fp32_layer_norm":
|
|
554
|
+
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
|
555
|
+
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
|
556
|
+
elif qk_norm == "layer_norm_across_heads":
|
|
557
|
+
# Lumina applies qk norm across all heads
|
|
558
|
+
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
|
559
|
+
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
|
560
|
+
elif qk_norm == "rms_norm":
|
|
561
|
+
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
562
|
+
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
563
|
+
elif qk_norm == "rms_norm_across_heads":
|
|
564
|
+
# LTX applies qk norm across all heads
|
|
565
|
+
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
|
|
566
|
+
self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
|
567
|
+
elif qk_norm == "l2":
|
|
568
|
+
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
|
|
569
|
+
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
|
|
570
|
+
else:
|
|
571
|
+
raise ValueError(
|
|
572
|
+
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
if cross_attention_norm is None:
|
|
576
|
+
self.norm_cross = None
|
|
577
|
+
elif cross_attention_norm == "layer_norm":
|
|
578
|
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
|
579
|
+
elif cross_attention_norm == "group_norm":
|
|
580
|
+
if self.added_kv_proj_dim is not None:
|
|
581
|
+
# The given `encoder_hidden_states` are initially of shape
|
|
582
|
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
|
583
|
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
|
584
|
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
|
585
|
+
# the number of channels for the group norm.
|
|
586
|
+
norm_cross_num_channels = added_kv_proj_dim
|
|
587
|
+
else:
|
|
588
|
+
norm_cross_num_channels = self.cross_attention_dim
|
|
589
|
+
|
|
590
|
+
self.norm_cross = nn.GroupNorm(
|
|
591
|
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
|
592
|
+
)
|
|
593
|
+
else:
|
|
594
|
+
raise ValueError(
|
|
595
|
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
|
596
|
+
)
|
|
597
|
+
|
|
598
|
+
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
|
599
|
+
|
|
600
|
+
if not self.only_cross_attention:
|
|
601
|
+
# only relevant for the `AddedKVProcessor` classes
|
|
602
|
+
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
|
603
|
+
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
|
604
|
+
else:
|
|
605
|
+
self.to_k = None
|
|
606
|
+
self.to_v = None
|
|
607
|
+
|
|
608
|
+
self.added_proj_bias = added_proj_bias
|
|
609
|
+
if self.added_kv_proj_dim is not None:
|
|
610
|
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
|
611
|
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
|
612
|
+
if self.context_pre_only is not None:
|
|
613
|
+
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
|
614
|
+
else:
|
|
615
|
+
self.add_q_proj = None
|
|
616
|
+
self.add_k_proj = None
|
|
617
|
+
self.add_v_proj = None
|
|
618
|
+
|
|
619
|
+
if not self.pre_only:
|
|
620
|
+
self.to_out = nn.ModuleList([])
|
|
621
|
+
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
|
622
|
+
self.to_out.append(nn.Dropout(dropout))
|
|
623
|
+
else:
|
|
624
|
+
self.to_out = None
|
|
625
|
+
|
|
626
|
+
if self.context_pre_only is not None and not self.context_pre_only:
|
|
627
|
+
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
|
628
|
+
else:
|
|
629
|
+
self.to_add_out = None
|
|
630
|
+
|
|
631
|
+
if qk_norm is not None and added_kv_proj_dim is not None:
|
|
632
|
+
if qk_norm == "layer_norm":
|
|
633
|
+
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
634
|
+
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
|
635
|
+
elif qk_norm == "fp32_layer_norm":
|
|
636
|
+
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
|
637
|
+
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
|
638
|
+
elif qk_norm == "rms_norm":
|
|
639
|
+
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
|
640
|
+
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
|
641
|
+
elif qk_norm == "rms_norm_across_heads":
|
|
642
|
+
# Wan applies qk norm across all heads
|
|
643
|
+
# Wan also doesn't apply a q norm
|
|
644
|
+
self.norm_added_q = None
|
|
645
|
+
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
|
646
|
+
else:
|
|
647
|
+
raise ValueError(
|
|
648
|
+
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
|
|
649
|
+
)
|
|
650
|
+
else:
|
|
651
|
+
self.norm_added_q = None
|
|
652
|
+
self.norm_added_k = None
|
|
653
|
+
|
|
654
|
+
# set attention processor
|
|
655
|
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
|
656
|
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
|
657
|
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
|
658
|
+
if processor is None:
|
|
659
|
+
processor = (
|
|
660
|
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
|
661
|
+
)
|
|
662
|
+
self.set_processor(processor)
|
|
663
|
+
|
|
664
|
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
|
665
|
+
r"""
|
|
666
|
+
Set the attention processor to use.
|
|
667
|
+
|
|
668
|
+
Args:
|
|
669
|
+
processor (`AttnProcessor`):
|
|
670
|
+
The attention processor to use.
|
|
671
|
+
"""
|
|
672
|
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
|
673
|
+
# pop `processor` from `self._modules`
|
|
674
|
+
if (
|
|
675
|
+
hasattr(self, "processor")
|
|
676
|
+
and isinstance(self.processor, torch.nn.Module)
|
|
677
|
+
and not isinstance(processor, torch.nn.Module)
|
|
678
|
+
):
|
|
679
|
+
self._modules.pop("processor")
|
|
680
|
+
|
|
681
|
+
self.processor = processor
|
|
682
|
+
|
|
683
|
+
def forward(
|
|
684
|
+
self,
|
|
685
|
+
hidden_states: torch.Tensor,
|
|
686
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
687
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
688
|
+
**cross_attention_kwargs,
|
|
689
|
+
) -> torch.Tensor:
|
|
690
|
+
r"""
|
|
691
|
+
The forward method of the `Attention` class.
|
|
692
|
+
|
|
693
|
+
Args:
|
|
694
|
+
hidden_states (`torch.Tensor`):
|
|
695
|
+
The hidden states of the query.
|
|
696
|
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
|
697
|
+
The hidden states of the encoder.
|
|
698
|
+
attention_mask (`torch.Tensor`, *optional*):
|
|
699
|
+
The attention mask to use. If `None`, no mask is applied.
|
|
700
|
+
**cross_attention_kwargs:
|
|
701
|
+
Additional keyword arguments to pass along to the cross attention.
|
|
702
|
+
|
|
703
|
+
Returns:
|
|
704
|
+
`torch.Tensor`: The output of the attention layer.
|
|
705
|
+
"""
|
|
706
|
+
# The `Attention` class can call different attention processors / attention functions
|
|
707
|
+
# here we simply pass along all tensors to the selected processor class
|
|
708
|
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
|
709
|
+
|
|
710
|
+
return self.processor(
|
|
711
|
+
self,
|
|
712
|
+
hidden_states,
|
|
713
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
714
|
+
attention_mask=attention_mask,
|
|
715
|
+
**cross_attention_kwargs,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
719
|
+
r"""
|
|
720
|
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
|
721
|
+
is the number of heads initialized while constructing the `Attention` class.
|
|
722
|
+
|
|
723
|
+
Args:
|
|
724
|
+
tensor (`torch.Tensor`): The tensor to reshape.
|
|
725
|
+
|
|
726
|
+
Returns:
|
|
727
|
+
`torch.Tensor`: The reshaped tensor.
|
|
728
|
+
"""
|
|
729
|
+
head_size = self.heads
|
|
730
|
+
batch_size, seq_len, dim = tensor.shape
|
|
731
|
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
|
732
|
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
|
733
|
+
return tensor
|
|
734
|
+
|
|
735
|
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
|
736
|
+
r"""
|
|
737
|
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
|
738
|
+
the number of heads initialized while constructing the `Attention` class.
|
|
739
|
+
|
|
740
|
+
Args:
|
|
741
|
+
tensor (`torch.Tensor`): The tensor to reshape.
|
|
742
|
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
|
743
|
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
|
744
|
+
|
|
745
|
+
Returns:
|
|
746
|
+
`torch.Tensor`: The reshaped tensor.
|
|
747
|
+
"""
|
|
748
|
+
head_size = self.heads
|
|
749
|
+
if tensor.ndim == 3:
|
|
750
|
+
batch_size, seq_len, dim = tensor.shape
|
|
751
|
+
extra_dim = 1
|
|
752
|
+
else:
|
|
753
|
+
batch_size, extra_dim, seq_len, dim = tensor.shape
|
|
754
|
+
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
|
|
755
|
+
tensor = tensor.permute(0, 2, 1, 3)
|
|
756
|
+
|
|
757
|
+
if out_dim == 3:
|
|
758
|
+
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
|
|
759
|
+
|
|
760
|
+
return tensor
|
|
761
|
+
|
|
762
|
+
def get_attention_scores(
|
|
763
|
+
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
|
764
|
+
) -> torch.Tensor:
|
|
765
|
+
r"""
|
|
766
|
+
Compute the attention scores.
|
|
767
|
+
|
|
768
|
+
Args:
|
|
769
|
+
query (`torch.Tensor`): The query tensor.
|
|
770
|
+
key (`torch.Tensor`): The key tensor.
|
|
771
|
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
|
772
|
+
|
|
773
|
+
Returns:
|
|
774
|
+
`torch.Tensor`: The attention probabilities/scores.
|
|
775
|
+
"""
|
|
776
|
+
dtype = query.dtype
|
|
777
|
+
if self.upcast_attention:
|
|
778
|
+
query = query.float()
|
|
779
|
+
key = key.float()
|
|
780
|
+
|
|
781
|
+
if attention_mask is None:
|
|
782
|
+
baddbmm_input = torch.empty(
|
|
783
|
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
|
784
|
+
)
|
|
785
|
+
beta = 0
|
|
786
|
+
else:
|
|
787
|
+
baddbmm_input = attention_mask
|
|
788
|
+
beta = 1
|
|
789
|
+
|
|
790
|
+
attention_scores = torch.baddbmm(
|
|
791
|
+
baddbmm_input,
|
|
792
|
+
query,
|
|
793
|
+
key.transpose(-1, -2),
|
|
794
|
+
beta=beta,
|
|
795
|
+
alpha=self.scale,
|
|
796
|
+
)
|
|
797
|
+
del baddbmm_input
|
|
798
|
+
|
|
799
|
+
if self.upcast_softmax:
|
|
800
|
+
attention_scores = attention_scores.float()
|
|
801
|
+
|
|
802
|
+
attention_probs = attention_scores.softmax(dim=-1)
|
|
803
|
+
del attention_scores
|
|
804
|
+
|
|
805
|
+
attention_probs = attention_probs.to(dtype)
|
|
806
|
+
|
|
807
|
+
return attention_probs
|
|
808
|
+
|
|
809
|
+
def prepare_attention_mask(
|
|
810
|
+
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
|
811
|
+
) -> torch.Tensor:
|
|
812
|
+
r"""
|
|
813
|
+
Prepare the attention mask for the attention computation.
|
|
814
|
+
|
|
815
|
+
Args:
|
|
816
|
+
attention_mask (`torch.Tensor`):
|
|
817
|
+
The attention mask to prepare.
|
|
818
|
+
target_length (`int`):
|
|
819
|
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
|
820
|
+
batch_size (`int`):
|
|
821
|
+
The batch size, which is used to repeat the attention mask.
|
|
822
|
+
out_dim (`int`, *optional*, defaults to `3`):
|
|
823
|
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
|
824
|
+
|
|
825
|
+
Returns:
|
|
826
|
+
`torch.Tensor`: The prepared attention mask.
|
|
827
|
+
"""
|
|
828
|
+
head_size = self.heads
|
|
829
|
+
if attention_mask is None:
|
|
830
|
+
return attention_mask
|
|
831
|
+
|
|
832
|
+
current_length: int = attention_mask.shape[-1]
|
|
833
|
+
if current_length != target_length:
|
|
834
|
+
if attention_mask.device.type == "mps":
|
|
835
|
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
|
836
|
+
# Instead, we can manually construct the padding tensor.
|
|
837
|
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
|
838
|
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
|
839
|
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
|
840
|
+
else:
|
|
841
|
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
|
842
|
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
|
843
|
+
# remaining_length: int = target_length - current_length
|
|
844
|
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
|
845
|
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
|
846
|
+
|
|
847
|
+
if out_dim == 3:
|
|
848
|
+
if attention_mask.shape[0] < batch_size * head_size:
|
|
849
|
+
attention_mask = attention_mask.repeat_interleave(
|
|
850
|
+
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
|
|
851
|
+
)
|
|
852
|
+
elif out_dim == 4:
|
|
853
|
+
attention_mask = attention_mask.unsqueeze(1)
|
|
854
|
+
attention_mask = attention_mask.repeat_interleave(
|
|
855
|
+
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
return attention_mask
|
|
859
|
+
|
|
860
|
+
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
|
861
|
+
r"""
|
|
862
|
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
|
863
|
+
`Attention` class.
|
|
864
|
+
|
|
865
|
+
Args:
|
|
866
|
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
|
867
|
+
|
|
868
|
+
Returns:
|
|
869
|
+
`torch.Tensor`: The normalized encoder hidden states.
|
|
870
|
+
"""
|
|
871
|
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
|
872
|
+
|
|
873
|
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
|
874
|
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
|
875
|
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
|
876
|
+
# Group norm norms along the channels dimension and expects
|
|
877
|
+
# input to be in the shape of (N, C, *). In this case, we want
|
|
878
|
+
# to norm along the hidden dimension, so we need to move
|
|
879
|
+
# (batch_size, sequence_length, hidden_size) ->
|
|
880
|
+
# (batch_size, hidden_size, sequence_length)
|
|
881
|
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
|
882
|
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
|
883
|
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
|
884
|
+
else:
|
|
885
|
+
assert False
|
|
886
|
+
|
|
887
|
+
return encoder_hidden_states
|
|
888
|
+
|
|
889
|
+
@torch.no_grad()
|
|
890
|
+
def fuse_projections(self, fuse=True):
|
|
891
|
+
device = self.to_q.weight.data.device
|
|
892
|
+
dtype = self.to_q.weight.data.dtype
|
|
893
|
+
|
|
894
|
+
if not self.is_cross_attention:
|
|
895
|
+
# fetch weight matrices.
|
|
896
|
+
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
|
897
|
+
in_features = concatenated_weights.shape[1]
|
|
898
|
+
out_features = concatenated_weights.shape[0]
|
|
899
|
+
|
|
900
|
+
# create a new single projection layer and copy over the weights.
|
|
901
|
+
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
|
902
|
+
self.to_qkv.weight.copy_(concatenated_weights)
|
|
903
|
+
if self.use_bias:
|
|
904
|
+
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
|
905
|
+
self.to_qkv.bias.copy_(concatenated_bias)
|
|
906
|
+
|
|
907
|
+
else:
|
|
908
|
+
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
|
909
|
+
in_features = concatenated_weights.shape[1]
|
|
910
|
+
out_features = concatenated_weights.shape[0]
|
|
911
|
+
|
|
912
|
+
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
|
913
|
+
self.to_kv.weight.copy_(concatenated_weights)
|
|
914
|
+
if self.use_bias:
|
|
915
|
+
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
|
916
|
+
self.to_kv.bias.copy_(concatenated_bias)
|
|
917
|
+
|
|
918
|
+
# handle added projections for SD3 and others.
|
|
919
|
+
if (
|
|
920
|
+
getattr(self, "add_q_proj", None) is not None
|
|
921
|
+
and getattr(self, "add_k_proj", None) is not None
|
|
922
|
+
and getattr(self, "add_v_proj", None) is not None
|
|
923
|
+
):
|
|
924
|
+
concatenated_weights = torch.cat(
|
|
925
|
+
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
|
|
926
|
+
)
|
|
927
|
+
in_features = concatenated_weights.shape[1]
|
|
928
|
+
out_features = concatenated_weights.shape[0]
|
|
929
|
+
|
|
930
|
+
self.to_added_qkv = nn.Linear(
|
|
931
|
+
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
|
932
|
+
)
|
|
933
|
+
self.to_added_qkv.weight.copy_(concatenated_weights)
|
|
934
|
+
if self.added_proj_bias:
|
|
935
|
+
concatenated_bias = torch.cat(
|
|
936
|
+
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
|
937
|
+
)
|
|
938
|
+
self.to_added_qkv.bias.copy_(concatenated_bias)
|
|
939
|
+
|
|
940
|
+
self.fused_projections = fuse
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
class AttnProcessor2_0:
|
|
944
|
+
r"""
|
|
945
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
|
946
|
+
"""
|
|
947
|
+
|
|
948
|
+
def __init__(self):
|
|
949
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
|
950
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
951
|
+
|
|
952
|
+
def __call__(
|
|
953
|
+
self,
|
|
954
|
+
attn: Attention,
|
|
955
|
+
hidden_states: torch.Tensor,
|
|
956
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
957
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
958
|
+
temb: Optional[torch.Tensor] = None,
|
|
959
|
+
*args,
|
|
960
|
+
**kwargs,
|
|
961
|
+
) -> torch.Tensor:
|
|
962
|
+
residual = hidden_states
|
|
963
|
+
if attn.spatial_norm is not None:
|
|
964
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
965
|
+
|
|
966
|
+
input_ndim = hidden_states.ndim
|
|
967
|
+
|
|
968
|
+
if input_ndim == 4:
|
|
969
|
+
batch_size, channel, height, width = hidden_states.shape
|
|
970
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
971
|
+
|
|
972
|
+
batch_size, sequence_length, _ = (
|
|
973
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
974
|
+
)
|
|
975
|
+
|
|
976
|
+
if attention_mask is not None:
|
|
977
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
978
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
|
979
|
+
# (batch, heads, source_length, target_length)
|
|
980
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
981
|
+
|
|
982
|
+
if attn.group_norm is not None:
|
|
983
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
984
|
+
|
|
985
|
+
query = attn.to_q(hidden_states)
|
|
986
|
+
|
|
987
|
+
if encoder_hidden_states is None:
|
|
988
|
+
encoder_hidden_states = hidden_states
|
|
989
|
+
elif attn.norm_cross:
|
|
990
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
991
|
+
|
|
992
|
+
key = attn.to_k(encoder_hidden_states)
|
|
993
|
+
value = attn.to_v(encoder_hidden_states)
|
|
994
|
+
|
|
995
|
+
inner_dim = key.shape[-1]
|
|
996
|
+
head_dim = inner_dim // attn.heads
|
|
997
|
+
|
|
998
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
999
|
+
|
|
1000
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
1001
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
1002
|
+
|
|
1003
|
+
if attn.norm_q is not None:
|
|
1004
|
+
query = attn.norm_q(query)
|
|
1005
|
+
if attn.norm_k is not None:
|
|
1006
|
+
key = attn.norm_k(key)
|
|
1007
|
+
|
|
1008
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
|
1009
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
1010
|
+
hidden_states = F.scaled_dot_product_attention(
|
|
1011
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
1012
|
+
)
|
|
1013
|
+
|
|
1014
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
1015
|
+
hidden_states = hidden_states.to(query.dtype)
|
|
1016
|
+
|
|
1017
|
+
# linear proj
|
|
1018
|
+
hidden_states = attn.to_out[0](hidden_states)
|
|
1019
|
+
# dropout
|
|
1020
|
+
hidden_states = attn.to_out[1](hidden_states)
|
|
1021
|
+
|
|
1022
|
+
if input_ndim == 4:
|
|
1023
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
1024
|
+
|
|
1025
|
+
if attn.residual_connection:
|
|
1026
|
+
hidden_states = hidden_states + residual
|
|
1027
|
+
|
|
1028
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
|
1029
|
+
|
|
1030
|
+
return hidden_states
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
class UNetMidBlock2D(nn.Module):
|
|
1034
|
+
"""
|
|
1035
|
+
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
|
|
1036
|
+
|
|
1037
|
+
Args:
|
|
1038
|
+
in_channels (`int`): The number of input channels.
|
|
1039
|
+
temb_channels (`int`): The number of temporal embedding channels.
|
|
1040
|
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
|
1041
|
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
|
1042
|
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
|
1043
|
+
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
|
1044
|
+
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
|
1045
|
+
model on tasks with long-range temporal dependencies.
|
|
1046
|
+
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
|
1047
|
+
resnet_groups (`int`, *optional*, defaults to 32):
|
|
1048
|
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
|
1049
|
+
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
|
1050
|
+
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
|
1051
|
+
Whether to use pre-normalization for the resnet blocks.
|
|
1052
|
+
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
|
1053
|
+
attention_head_dim (`int`, *optional*, defaults to 1):
|
|
1054
|
+
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
|
1055
|
+
the number of input channels.
|
|
1056
|
+
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
|
1057
|
+
|
|
1058
|
+
Returns:
|
|
1059
|
+
`torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
|
|
1060
|
+
height, width)`.
|
|
1061
|
+
|
|
1062
|
+
"""
|
|
1063
|
+
|
|
1064
|
+
def __init__(
|
|
1065
|
+
self,
|
|
1066
|
+
in_channels: int,
|
|
1067
|
+
temb_channels: int,
|
|
1068
|
+
dropout: float = 0.0,
|
|
1069
|
+
num_layers: int = 1,
|
|
1070
|
+
resnet_eps: float = 1e-6,
|
|
1071
|
+
resnet_time_scale_shift: str = "default", # default, spatial
|
|
1072
|
+
resnet_act_fn: str = "swish",
|
|
1073
|
+
resnet_groups: int = 32,
|
|
1074
|
+
attn_groups: Optional[int] = None,
|
|
1075
|
+
resnet_pre_norm: bool = True,
|
|
1076
|
+
add_attention: bool = True,
|
|
1077
|
+
attention_head_dim: int = 1,
|
|
1078
|
+
output_scale_factor: float = 1.0,
|
|
1079
|
+
):
|
|
1080
|
+
super().__init__()
|
|
1081
|
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
|
1082
|
+
self.add_attention = add_attention
|
|
1083
|
+
|
|
1084
|
+
if attn_groups is None:
|
|
1085
|
+
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
|
1086
|
+
|
|
1087
|
+
# there is always at least one resnet
|
|
1088
|
+
if resnet_time_scale_shift == "spatial":
|
|
1089
|
+
resnets = [
|
|
1090
|
+
ResnetBlockCondNorm2D(
|
|
1091
|
+
in_channels=in_channels,
|
|
1092
|
+
out_channels=in_channels,
|
|
1093
|
+
temb_channels=temb_channels,
|
|
1094
|
+
eps=resnet_eps,
|
|
1095
|
+
groups=resnet_groups,
|
|
1096
|
+
dropout=dropout,
|
|
1097
|
+
time_embedding_norm="spatial",
|
|
1098
|
+
non_linearity=resnet_act_fn,
|
|
1099
|
+
output_scale_factor=output_scale_factor,
|
|
1100
|
+
)
|
|
1101
|
+
]
|
|
1102
|
+
else:
|
|
1103
|
+
resnets = [
|
|
1104
|
+
ResnetBlock2D(
|
|
1105
|
+
in_channels=in_channels,
|
|
1106
|
+
out_channels=in_channels,
|
|
1107
|
+
temb_channels=temb_channels,
|
|
1108
|
+
eps=resnet_eps,
|
|
1109
|
+
groups=resnet_groups,
|
|
1110
|
+
dropout=dropout,
|
|
1111
|
+
time_embedding_norm=resnet_time_scale_shift,
|
|
1112
|
+
non_linearity=resnet_act_fn,
|
|
1113
|
+
output_scale_factor=output_scale_factor,
|
|
1114
|
+
pre_norm=resnet_pre_norm,
|
|
1115
|
+
)
|
|
1116
|
+
]
|
|
1117
|
+
attentions = []
|
|
1118
|
+
|
|
1119
|
+
if attention_head_dim is None:
|
|
1120
|
+
attention_head_dim = in_channels
|
|
1121
|
+
|
|
1122
|
+
for _ in range(num_layers):
|
|
1123
|
+
if self.add_attention:
|
|
1124
|
+
attentions.append(
|
|
1125
|
+
Attention(
|
|
1126
|
+
in_channels,
|
|
1127
|
+
heads=in_channels // attention_head_dim,
|
|
1128
|
+
dim_head=attention_head_dim,
|
|
1129
|
+
rescale_output_factor=output_scale_factor,
|
|
1130
|
+
eps=resnet_eps,
|
|
1131
|
+
norm_num_groups=attn_groups,
|
|
1132
|
+
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
|
1133
|
+
residual_connection=True,
|
|
1134
|
+
bias=True,
|
|
1135
|
+
upcast_softmax=True,
|
|
1136
|
+
_from_deprecated_attn_block=True,
|
|
1137
|
+
)
|
|
1138
|
+
)
|
|
1139
|
+
else:
|
|
1140
|
+
attentions.append(None)
|
|
1141
|
+
|
|
1142
|
+
if resnet_time_scale_shift == "spatial":
|
|
1143
|
+
resnets.append(
|
|
1144
|
+
ResnetBlockCondNorm2D(
|
|
1145
|
+
in_channels=in_channels,
|
|
1146
|
+
out_channels=in_channels,
|
|
1147
|
+
temb_channels=temb_channels,
|
|
1148
|
+
eps=resnet_eps,
|
|
1149
|
+
groups=resnet_groups,
|
|
1150
|
+
dropout=dropout,
|
|
1151
|
+
time_embedding_norm="spatial",
|
|
1152
|
+
non_linearity=resnet_act_fn,
|
|
1153
|
+
output_scale_factor=output_scale_factor,
|
|
1154
|
+
)
|
|
1155
|
+
)
|
|
1156
|
+
else:
|
|
1157
|
+
resnets.append(
|
|
1158
|
+
ResnetBlock2D(
|
|
1159
|
+
in_channels=in_channels,
|
|
1160
|
+
out_channels=in_channels,
|
|
1161
|
+
temb_channels=temb_channels,
|
|
1162
|
+
eps=resnet_eps,
|
|
1163
|
+
groups=resnet_groups,
|
|
1164
|
+
dropout=dropout,
|
|
1165
|
+
time_embedding_norm=resnet_time_scale_shift,
|
|
1166
|
+
non_linearity=resnet_act_fn,
|
|
1167
|
+
output_scale_factor=output_scale_factor,
|
|
1168
|
+
pre_norm=resnet_pre_norm,
|
|
1169
|
+
)
|
|
1170
|
+
)
|
|
1171
|
+
|
|
1172
|
+
self.attentions = nn.ModuleList(attentions)
|
|
1173
|
+
self.resnets = nn.ModuleList(resnets)
|
|
1174
|
+
|
|
1175
|
+
self.gradient_checkpointing = False
|
|
1176
|
+
|
|
1177
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
1178
|
+
hidden_states = self.resnets[0](hidden_states, temb)
|
|
1179
|
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
|
1180
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
1181
|
+
if attn is not None:
|
|
1182
|
+
hidden_states = attn(hidden_states, temb=temb)
|
|
1183
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
|
1184
|
+
else:
|
|
1185
|
+
if attn is not None:
|
|
1186
|
+
hidden_states = attn(hidden_states, temb=temb)
|
|
1187
|
+
hidden_states = resnet(hidden_states, temb)
|
|
1188
|
+
|
|
1189
|
+
return hidden_states
|
|
1190
|
+
|
|
1191
|
+
|
|
+class DownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        temb_channels=None,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
+                )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            if resnet_time_scale_shift == "spatial":
+                resnets.append(
+                    ResnetBlockCondNorm2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm="spatial",
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                    )
+                )
+            else:
+                resnets.append(
+                    ResnetBlock2D(
+                        in_channels=input_channels,
+                        out_channels=out_channels,
+                        temb_channels=temb_channels,
+                        eps=resnet_eps,
+                        groups=resnet_groups,
+                        dropout=dropout,
+                        time_embedding_norm=resnet_time_scale_shift,
+                        non_linearity=resnet_act_fn,
+                        output_scale_factor=output_scale_factor,
+                        pre_norm=resnet_pre_norm,
+                    )
+                )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=temb)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class Encoder(nn.Module):
+    r"""
+    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
+
+    Args:
+        in_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        out_channels (`int`, *optional*, defaults to 3):
+            The number of output channels.
+        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+            options.
+        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+            The number of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups for normalization.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+        double_z (`bool`, *optional*, defaults to `True`):
+            Whether to double the number of output channels for the last block.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 2,
+        norm_num_groups: int = 32,
+        act_fn: str = "silu",
+        double_z: bool = True,
+        mid_block_add_attention=True,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+
+        self.conv_in = nn.Conv2d(
+            in_channels,
+            block_out_channels[0],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        self.down_blocks = nn.ModuleList([])
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = DownEncoderBlock2D(
+                num_layers=self.layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                add_downsample=not is_final_block,
+                resnet_eps=1e-6,
+                downsample_padding=0,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                # attention_head_dim=output_channel,
+                # temb_channels=None,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attention_head_dim=block_out_channels[-1],
+            resnet_groups=norm_num_groups,
+            temb_channels=None,
+            add_attention=mid_block_add_attention,
+        )
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+        self.conv_act = nn.SiLU()
+
+        conv_out_channels = 2 * out_channels if double_z else out_channels
+        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+        self.gradient_checkpointing = False
+
+    def forward(self, sample: torch.Tensor) -> torch.Tensor:
+        r"""The forward method of the `Encoder` class."""
+
+        sample = self.conv_in(sample)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            # down
+            for down_block in self.down_blocks:
+                sample = self._gradient_checkpointing_func(down_block, sample)
+            # middle
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
+
+        else:
+            # down
+            for down_block in self.down_blocks:
+                sample = down_block(sample)
+
+            # middle
+            sample = self.mid_block(sample)
+
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample
+
+
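A shape check for the `Encoder` as configured by `Flux2VAE` below (`block_out_channels=(128, 256, 512, 512)`, `layers_per_block=2`, `double_z=True`, `out_channels=latent_channels=32`): only the first three down blocks downsample, so spatial resolution drops by 8x and the output carries `2 * latent_channels` channels of moments. A rough sketch under those assumptions, not part of the diff:

# Rough shape check for the encoder under the assumed Flux2 configuration.
import torch

enc = Encoder(
    in_channels=3,
    out_channels=32,                                  # latent_channels
    down_block_types=("DownEncoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    double_z=True,
)
with torch.no_grad():
    moments = enc(torch.randn(1, 3, 256, 256))
print(moments.shape)                                  # torch.Size([1, 64, 32, 32])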
+class Decoder(nn.Module):
+    r"""
+    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
+
+    Args:
+        in_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        out_channels (`int`, *optional*, defaults to 3):
+            The number of output channels.
+        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+            The number of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups for normalization.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            The normalization type to use. Can be either `"group"` or `"spatial"`.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 2,
+        norm_num_groups: int = 32,
+        act_fn: str = "silu",
+        norm_type: str = "group",  # group, spatial
+        mid_block_add_attention=True,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+
+        self.conv_in = nn.Conv2d(
+            in_channels,
+            block_out_channels[-1],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        self.up_blocks = nn.ModuleList([])
+
+        temb_channels = in_channels if norm_type == "spatial" else None
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+            attention_head_dim=block_out_channels[-1],
+            resnet_groups=norm_num_groups,
+            temb_channels=temb_channels,
+            add_attention=mid_block_add_attention,
+        )
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+
+            is_final_block = i == len(block_out_channels) - 1
+
+            up_block = UpDecoderBlock2D(
+                num_layers=self.layers_per_block + 1,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                # prev_output_channel=prev_output_channel,
+                add_upsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                # attention_head_dim=output_channel,
+                temb_channels=temb_channels,
+                resnet_time_scale_shift=norm_type,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        if norm_type == "spatial":
+            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+        else:
+            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""The forward method of the `Decoder` class."""
+
+        sample = self.conv_in(sample)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            # middle
+            sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
+        else:
+            # middle
+            sample = self.mid_block(sample, latent_embeds)
+
+            # up
+            for up_block in self.up_blocks:
+                sample = up_block(sample, latent_embeds)
+
+        # post-process
+        if latent_embeds is None:
+            sample = self.conv_norm_out(sample)
+        else:
+            sample = self.conv_norm_out(sample, latent_embeds)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample
+
+
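The `Decoder` mirrors this: each up block stacks `layers_per_block + 1` resnets and every block except the last upsamples, so the latent grid grows back by 8x. A matching sketch under the same assumed configuration, illustrative only:

# Rough shape check for the decoder under the assumed Flux2 configuration.
import torch

dec = Decoder(
    in_channels=32,                                   # latent_channels
    out_channels=3,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
)
with torch.no_grad():
    image = dec(torch.randn(1, 32, 32, 32))
print(image.shape)                                    # torch.Size([1, 3, 256, 256])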
+class Flux2VAE(PreTrainedModel):
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for its generic methods
+    implemented for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to 32): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
+        force_upcast (`bool`, *optional*, default to `True`):
+            If enabled, it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL.
+            The VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+        mid_block_add_attention (`bool`, *optional*, default to `True`):
+            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
+            mid_block will only have resnet blocks.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = (
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+            "DownEncoderBlock2D",
+        ),
+        up_block_types: Tuple[str, ...] = (
+            "UpDecoderBlock2D",
+            "UpDecoderBlock2D",
+            "UpDecoderBlock2D",
+            "UpDecoderBlock2D",
+        ),
+        block_out_channels: Tuple[int, ...] = (
+            128,
+            256,
+            512,
+            512,
+        ),
+        layers_per_block: int = 2,
+        act_fn: str = "silu",
+        latent_channels: int = 32,
+        norm_num_groups: int = 32,
+        sample_size: int = 1024,  # YiYi notes: not sure
+        force_upcast: bool = True,
+        use_quant_conv: bool = True,
+        use_post_quant_conv: bool = True,
+        mid_block_add_attention: bool = True,
+        batch_norm_eps: float = 1e-4,
+        batch_norm_momentum: float = 0.1,
+        patch_size: Tuple[int, int] = (2, 2),
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float32,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=True,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            norm_num_groups=norm_num_groups,
+            act_fn=act_fn,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
+
+        self.bn = nn.BatchNorm2d(
+            math.prod(patch_size) * latent_channels,
+            eps=batch_norm_eps,
+            momentum=batch_norm_momentum,
+            affine=False,
+            track_running_stats=True,
+        )
+
+        self.use_slicing = False
+        self.use_tiling = False
+
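Note that `self.bn` is never applied as a layer in the forward pass; `encode` and `decode` below only read its running mean and variance to whiten and de-whiten the patchified latents. Its channel count is the patchified latent width. A small arithmetic sketch, not part of the diff:

# Channel count tracked by the latent batch norm under the defaults above.
import math

patch_size = (2, 2)
latent_channels = 32
print(math.prod(patch_size) * latent_channels)        # 128, matching the slice h[:, :128] in encode()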
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ):
+        """
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded images. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
+        h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2)
+        h = h[:, :128]
+        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype)
+        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
+            h.device, h.dtype
+        )
+        h = (h - latents_bn_mean) / latents_bn_std
+        return h
+
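`encode` folds each 2x2 patch of the latent map into the channel dimension, keeps the first 128 channels (the mean half of the `double_z` moments, since the `(C P Q)` pattern orders output channels by `C` first), and standardizes them with the batch-norm running statistics. A worked shape example of the `rearrange` step, assuming a 256x256 input image; illustrative, not part of the diff:

# Shapes through encode() for a 256x256 input.
import torch
from einops import rearrange

moments = torch.randn(1, 64, 32, 32)                  # encoder output: 2 * latent_channels, H/8, W/8
packed = rearrange(moments, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2)
print(packed.shape)                                   # torch.Size([1, 256, 16, 16])
latents = packed[:, :128]                             # patchified mean half, then (x - mean) / std
print(latents.shape)                                  # torch.Size([1, 128, 16, 16])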
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
+        dec = self.decoder(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return dec
+
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ):
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype)
+        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
+            z.device, z.dtype
+        )
+        z = z * latents_bn_std + latents_bn_mean
+        z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2)
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z)
+
+        if not return_dict:
+            return (decoded,)
+
+        return decoded
+
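`decode` runs the inverse of the two steps above: it rescales by the stored running statistics and unfolds the 128 channels back into a 32-channel latent map at twice the spatial resolution before calling the decoder. The mirror-image shape sketch, illustrative and not part of the diff:

# Shapes through decode(), mirroring the encode() example above.
import torch
from einops import rearrange

z = torch.randn(1, 128, 16, 16)                       # normalized, patchified latents
z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2)
print(z.shape)                                        # torch.Size([1, 32, 32, 32]), fed to self.decoder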
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True):
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+                `tuple` is returned.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        return moments
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping 64x64 tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                if self.config.use_post_quant_conv:
+                    tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        if not return_dict:
+            return (dec,)
+
+        return dec
+
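The tiled paths above only run when `use_tiling` is enabled and the tile attributes (`tile_sample_min_size`, `tile_latent_min_size`, `tile_overlap_factor`) have been set; the constructor leaves tiling off by default. When they do run, `blend_v`/`blend_h` cross-fade the overlapping strip between neighbouring tiles with a linear ramp instead of cutting a hard seam. A small numeric sketch of the horizontal blend, not part of the diff:

# Linear cross-fade over a 4-column overlap, as blend_h does.
import torch

left = torch.ones(1, 1, 1, 8)                         # tile already placed in the output row
right = torch.zeros(1, 1, 1, 8)                       # incoming tile to its right
blend_extent = 4
for x in range(blend_extent):
    right[:, :, :, x] = left[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + right[:, :, :, x] * (
        x / blend_extent
    )
print(right[0, 0, 0, :blend_extent])                  # tensor([1.0000, 0.7500, 0.5000, 0.2500])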
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ):
+        r"""
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return dec
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> "Flux2VAE":
+        model = cls(device="meta", dtype=dtype, **kwargs)
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
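`from_state_dict` builds the module on the `meta` device, which is why `load_state_dict(..., assign=True)` is used before materializing the weights on the target device. A hedged usage sketch of the class as a whole; `state_dict` is assumed to already hold the VAE weights, and how it is obtained is outside the scope of this file:

# Round-tripping an image through Flux2VAE (illustrative usage only, not part of the diff).
import torch

vae = Flux2VAE.from_state_dict(state_dict, device="cuda:0", dtype=torch.bfloat16)

image = torch.rand(1, 3, 1024, 1024, device="cuda:0", dtype=torch.bfloat16) * 2 - 1
with torch.no_grad():
    latents = vae.encode(image)                       # (1, 128, 64, 64) normalized, patchified latents
    recon = vae.decode(latents)                       # (1, 3, 1024, 1024)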