deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/diffunet.py
@@ -0,0 +1,997 @@

# This file is a concatenation of DiffPIR codes available here: https://github.com/yuanzhi-zhu/DiffPIR/tree/main
# This code is taken (with minor modifications) from https://github.com/yuanzhi-zhu/DiffPIR/tree/main

import math  # re-imported in the utilities section below (concatenated file)
import torch
import torch as th  # the ``th`` alias is used throughout; also re-imported below
from .utils import get_weights_url
from abc import abstractmethod
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class DiffUNet(nn.Module):
    r"""
    Diffusion UNet model.

    This is the model with attention and timestep embeddings from `Ho et al. <https://arxiv.org/abs/2006.11239>`_;
    code is adapted from https://github.com/jychoi118/ilvr_adm.

    It is possible to choose the `standard model <https://arxiv.org/abs/2108.02938>`_
    with 128 hidden channels per layer (trained on FFHQ)
    and a `larger model <https://arxiv.org/abs/2105.05233>`_ with 256 hidden channels per layer (trained on ImageNet128).

    A pretrained network (for ``in_channels=out_channels=3``)
    can be downloaded by setting ``pretrained='download'``.

    The network can handle images of size :math:`2^{n_1}\times 2^{n_2}` with :math:`n_1,n_2 \geq 5`.

    :param int in_channels: channels in the input Tensor.
    :param int out_channels: channels in the output Tensor.
    :param bool large_model: if True, use the large model with 256 hidden channels per layer trained on ImageNet128
        (weights size: 2.1 GB).
        Otherwise, use the smaller model with 128 hidden channels per layer trained on FFHQ (weights size: 357 MB).
    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized
        at random using PyTorch's default initialization.
        If ``pretrained='download'``, the weights will be downloaded from an online repository
        (only available for 3 input and output channels).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        large_model=False,
        use_fp16=False,
        pretrained="download",
    ):
        super().__init__()

        if large_model:
            model_channels = 256
            num_res_blocks = 2
            attention_resolutions = "8,16,32"
        else:
            model_channels = 128
            num_res_blocks = 1
            attention_resolutions = "16"

        dropout = 0.1
        conv_resample = True
        dims = 2
        num_classes = None
        use_checkpoint = False
        num_heads = 4
        num_head_channels = 64
        num_heads_upsample = -1
        use_scale_shift_norm = True
        resblock_updown = True
        use_new_attention_order = False

        # the pretrained models predict a noise map and a variance map (6 channels)
        out_channels = 6 if out_channels == 3 else out_channels
        channel_mult = (1, 1, 2, 2, 4, 4)

        image_size = 256
        attention_ds = []
        for res in attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))
        attention_resolutions = tuple(attention_ds)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

        if pretrained is not None:
            if pretrained == "download":
                if in_channels == 3 and out_channels == 6 and not large_model:
                    name = "diffusion_ffhq_10m.pt"
                elif in_channels == 3 and out_channels == 6 and large_model:
                    name = "diffusion_openai.pt"
                else:
                    raise ValueError(
                        "no existing pretrained model matches the requested configuration"
                    )
                url = get_weights_url(model_name="diffunet", file_name=name)
                ckpt = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                ckpt = torch.load(pretrained, map_location=lambda storage, loc: storage)

            self.load_state_dict(ckpt, strict=True)

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, t, y=None, type_t="noise_level"):
        r"""
        Apply the model to an input batch.

        This function takes a noisy image and either a timestep or a noise level as input. Depending on the nature
        of ``t``, the model returns either a noise map (if ``type_t='timestep'``) or a denoised image (if
        ``type_t='noise_level'``).

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps or noise levels.
        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
        :param type_t: nature of the embedding ``t``. In a traditional diffusion model, and in the authors' code,
            ``t`` is a timestep linked to a noise level; in this case, set ``type_t='timestep'``. ``t`` can also
            be a noise level directly, in which case the model acts as a denoiser; set ``type_t='noise_level'``.
            Default: ``'noise_level'``.
        :return: an [N x C x ...] Tensor of outputs. Either a noise map (if ``type_t='timestep'``) or a denoised
            image (if ``type_t='noise_level'``).
        """
        if type_t == "timestep":
            return self.forward_diffusion(x, t, y=y)
        elif type_t == "noise_level":
            return self.forward_denoise(x, t, y=y)
        else:
            raise ValueError('type_t must be either "timestep" or "noise_level"')

    def forward_diffusion(self, x, timesteps, y=None):
        r"""
        Apply the model to an input batch.

        This function takes a noisy image and a timestep as input (and not a noise level) and estimates the noise
        map in the input image.
        The image is assumed to be in range [-1, 1] and to have dimensions with width and height divisible by a
        power of 2.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_alpha_prod(
        self, beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=1000
    ):
        """
        Get the alpha sequences; this is necessary for mapping noise levels to timesteps when performing pure
        denoising.
        """
        betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        betas = torch.from_numpy(
            betas
        )  # .to(self.device) Removing this for now, can be done outside
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # This is \overline{\alpha}_t

        # Useful sequences deriving from alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        reduced_alpha_cumprod = torch.div(sqrt_1m_alphas_cumprod, sqrt_alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
        return (
            reduced_alpha_cumprod,
            sqrt_recip_alphas_cumprod,
            sqrt_recipm1_alphas_cumprod,
        )

    def find_nearest(self, array, value):
        """
        Return the index of the entry of ``array`` closest to ``value``.
        """
        array = np.asarray(array)
        if isinstance(value, torch.Tensor):
            value = np.asarray(value.cpu())
        idx = (np.abs(array - value)).argmin()
        return idx

    def forward_denoise(self, x, sigma, y=None):
        r"""
        Apply the denoising model to an input batch.

        This function takes a noisy image and a noise level as input (and not a timestep) and estimates the
        noiseless underlying image in the input image.
        The input image is assumed to be in range [0, 1] (up to noise) and to have dimensions with width and
        height divisible by a power of 2.

        :param x: an [N x C x ...] Tensor of inputs.
        :param sigma: a 1-D batch of noise levels.
        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
        :return: an [N x C x ...] Tensor of outputs.
        """
        x = 2.0 * x - 1.0
        (
            reduced_alpha_cumprod,
            sqrt_recip_alphas_cumprod,
            sqrt_recipm1_alphas_cumprod,
        ) = self.get_alpha_prod()
        timesteps = self.find_nearest(
            reduced_alpha_cumprod, sigma * 2
        )  # Factor 2 because the image is rescaled to [-1, 1]

        noise_est_sample_var = self.forward_diffusion(
            x, torch.tensor([timesteps]).to(x.device), y=y
        )
        noise_est = noise_est_sample_var[:, :3, ...]  # first 3 channels: noise estimate; the rest is the variance map
        denoised = (
            sqrt_recip_alphas_cumprod[timesteps] * x
            - sqrt_recipm1_alphas_cumprod[timesteps] * noise_est
        )
        denoised = denoised.clamp(-1, 1)
        return denoised / 2.0 + 0.5
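

# Hedged usage sketch (not part of the original file): how the two forward modes
# of ``DiffUNet`` are meant to be called. ``pretrained=None`` keeps the snippet
# self-contained (no weight download); with random weights the output is of
# course not a meaningful restoration.
def _demo_diffunet_modes():
    model = DiffUNet(pretrained=None).eval()
    x = torch.rand(1, 3, 64, 64)  # clean image in [0, 1], sides are powers of 2
    sigma = 25.0 / 255.0
    y = x + sigma * torch.randn_like(x)  # noisy observation
    with torch.no_grad():
        # Denoiser mode: ``t`` is a noise level, output is a denoised image in [0, 1].
        x_hat = model(y, sigma, type_t="noise_level")
        # Diffusion mode: ``t`` is a timestep index, input lives in [-1, 1], and the
        # output stacks the noise estimate and the variance map (6 channels).
        eps_and_var = model(2.0 * y - 1.0, torch.tensor([500]), type_t="timestep")
    return x_hat.shape, eps_and_var.shape  # (1, 3, 64, 64), (1, 6, 64, 64)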


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]
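

# Hedged shape sketch (not part of the original file): ``AttentionPool2d``
# prepends a mean token, runs QKV attention over all positions, and returns that
# token as a pooled feature vector. The sizes below are illustrative only.
def _demo_attention_pool():
    pool = AttentionPool2d(spacial_dim=8, embed_dim=64, num_heads_channels=32)
    x = th.randn(2, 64, 8, 8)  # [N, C, H, W] with H * W == spacial_dim ** 2
    return pool(x).shape  # (2, 64): one pooled vector per batch element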


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
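

# Hedged sketch (not part of the original file): ``TimestepEmbedSequential``
# routes the embedding only to ``TimestepBlock`` children; plain layers such as
# ``nn.SiLU`` receive ``x`` alone. ``ResBlock`` is defined further down in this
# file, which is fine at call time.
def _demo_timestep_sequential():
    block = TimestepEmbedSequential(ResBlock(32, 128, dropout=0.0), nn.SiLU())
    x = th.randn(1, 32, 16, 16)
    emb = th.randn(1, 128)  # timestep embedding, consumed by the ResBlock only
    return block(x, emb).shape  # (1, 32, 16, 16)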


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
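

# Hedged shape sketch (not part of the original file): nearest-neighbour
# upsampling and strided downsampling halve/double the spatial side; channels
# change only when ``use_conv=True`` and ``out_channels`` is given.
def _demo_resample_shapes():
    x = th.randn(1, 32, 16, 16)
    up = Upsample(32, use_conv=True, out_channels=64)
    down = Downsample(32, use_conv=False)  # avg-pool path keeps the channel count
    return up(x).shape, down(x).shape  # (1, 64, 32, 32), (1, 32, 8, 8)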


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
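

# Hedged sketch (not part of the original file): with ``use_scale_shift_norm=True``
# the embedding is split into a (scale, shift) pair that modulates the normalised
# features, i.e. h <- norm(h) * (1 + scale) + shift, as in the code above.
def _demo_resblock_scale_shift():
    block = ResBlock(
        32, emb_channels=128, dropout=0.0, out_channels=64, use_scale_shift_norm=True
    )
    x = th.randn(2, 32, 16, 16)
    emb = th.randn(2, 128)
    return block(x, emb).shape  # (2, 64, 16, 16); the skip path uses a 1x1 conv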


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
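

# Hedged sketch (not part of the original file): the block flattens the spatial
# grid to a token axis, attends over it, and adds the result back residually, so
# the output shape always matches the input. ``proj_out`` starts zeroed, so a
# freshly initialised block acts as an identity map.
def _demo_attention_block():
    attn = AttentionBlock(64, num_head_channels=32)
    x = th.randn(1, 64, 8, 8)
    out = attn(x)
    return out.shape, bool(th.allclose(out, x))  # (1, 64, 8, 8), True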


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:

        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
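

# Hedged shape sketch (not part of the original file): the two attention modules
# differ only in how the (q, k, v) triplet is laid out along the channel axis;
# legacy interleaves per head ([N, H*3*C, T]), the new order stacks q, k, v
# first ([N, 3*H*C, T]). Both return [N, H*C, T].
def _demo_qkv_shapes():
    heads, ch, tokens = 2, 8, 16
    qkv = th.randn(1, 3 * heads * ch, tokens)
    out_new = QKVAttention(heads)(qkv)
    out_legacy = QKVAttentionLegacy(heads)(qkv)  # same shape, different head split
    return out_new.shape, out_legacy.shape  # both (1, 16, 16)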


# NB: an identical ``checkpoint`` is defined again in the utilities section
# below; this file is a concatenation, and the later definition rebinds the
# name without changing behaviour.
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
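

# Hedged usage sketch (not part of the original file): maintaining an EMA copy
# of a model's weights, updated once per optimisation step.
def _demo_update_ema():
    model = nn.Linear(4, 4)
    ema_model = nn.Linear(4, 4)
    ema_model.load_state_dict(model.state_dict())  # start from the same weights
    # ... after each training step on ``model``:
    update_ema(ema_model.parameters(), model.parameters(), rate=0.999)
    return ema_model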


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
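

# Hedged sketch (not part of the original file): the embedding concatenates
# cosines and sines at ``dim // 2`` geometrically spaced frequencies, so nearby
# timesteps get nearby vectors; an extra zero column pads odd dimensions.
def _demo_timestep_embedding():
    emb = timestep_embedding(th.tensor([0, 10, 500]), dim=128)
    return emb.shape, emb[0, 0].item()  # (3, 128); cos(0 * f) == 1.0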


# Identical to the ``checkpoint`` defined earlier (kept as in the concatenated source).
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
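

# Hedged sketch (not part of the original file): with ``flag=True`` the forward
# pass runs under ``no_grad`` and activations are recomputed in ``backward``;
# the gradients match the non-checkpointed path.
def _demo_checkpoint():
    layer = nn.Linear(8, 8)
    x = th.randn(2, 8, requires_grad=True)
    out = checkpoint(layer, (x,), layer.parameters(), True)
    out.sum().backward()
    return x.grad.shape  # (2, 8)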


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()