deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,997 @@
+# This file is a concatenation of the DiffPIR code available at
+# https://github.com/yuanzhi-zhu/DiffPIR/tree/main, taken with minor modifications.
+
+import torch
+from .utils import get_weights_url
+from abc import abstractmethod
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DiffUNet(nn.Module):
+    r"""
+    Diffusion UNet model.
+
+    This is the model with attention and timestep embeddings from `Ho et al. <https://arxiv.org/abs/2006.11239>`_;
+    the code is adapted from https://github.com/jychoi118/ilvr_adm.
+
+    It is possible to choose the `standard model <https://arxiv.org/abs/2108.02938>`_
+    with 128 hidden channels per layer (trained on FFHQ)
+    or a `larger model <https://arxiv.org/abs/2105.05233>`_ with 256 hidden channels per layer (trained on ImageNet128).
+
+    A pretrained network for ``in_channels=out_channels=3``
+    can be downloaded by setting ``pretrained='download'``.
+
+    The network can handle images of size :math:`2^{n_1}\times 2^{n_2}` with :math:`n_1,n_2 \geq 5`.
+
+    :param int in_channels: channels in the input Tensor.
+    :param int out_channels: channels in the output Tensor.
+    :param bool large_model: if True, use the large model with 256 hidden channels per layer trained on ImageNet128
+        (weights size: 2.1 GB).
+        Otherwise, use the smaller model with 128 hidden channels per layer trained on FFHQ (weights size: 357 MB).
+    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at
+        random using PyTorch's default initialization.
+        If ``pretrained='download'``, the weights will be downloaded from an online repository
+        (only available for 3 input and output channels).
+        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
+        See :ref:`pretrained-weights <pretrained-weights>` for more details.
+    """
+
+    def __init__(
+        self,
+        in_channels=3,
+        out_channels=3,
+        large_model=False,
+        use_fp16=False,
+        pretrained="download",
+    ):
+        super().__init__()
+
+        if large_model:
+            model_channels = 256
+            num_res_blocks = 2
+            attention_resolutions = "8,16,32"
+        else:
+            model_channels = 128
+            num_res_blocks = 1
+            attention_resolutions = "16"
+
+        dropout = 0.1
+        conv_resample = True
+        dims = 2
+        num_classes = None
+        use_checkpoint = False
+        num_heads = 4
+        num_head_channels = 64
+        num_heads_upsample = -1
+        use_scale_shift_norm = True
+        resblock_updown = True
+        use_new_attention_order = False
+
+        # the pretrained networks predict noise and variance (3 + 3 channels)
+        out_channels = 6 if out_channels == 3 else out_channels
+        channel_mult = (1, 1, 2, 2, 4, 4)
+
+        image_size = 256
+        attention_ds = []
+        for res in attention_resolutions.split(","):
+            attention_ds.append(image_size // int(res))
+        attention_resolutions = tuple(attention_ds)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+        ch = input_ch = int(channel_mult[0] * model_channels)
+        self.input_blocks = nn.ModuleList(
+            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+        )
+        self._feature_size = ch
+        input_block_chans = [ch]
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=int(mult * model_channels),
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = int(mult * model_channels)
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=num_head_channels,
+                use_new_attention_order=use_new_attention_order,
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=int(model_channels * mult),
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = int(model_channels * mult)
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads_upsample,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+        )
+
+        if pretrained is not None:
+            if pretrained == "download":
+                if in_channels == 3 and out_channels == 6 and not large_model:
+                    name = "diffusion_ffhq_10m.pt"
+                elif in_channels == 3 and out_channels == 6 and large_model:
+                    name = "diffusion_openai.pt"
+                else:
+                    raise ValueError(
+                        "no existing pretrained model matches the requested configuration"
+                    )
+                url = get_weights_url(model_name="diffunet", file_name=name)
+                ckpt = torch.hub.load_state_dict_from_url(
+                    url, map_location=lambda storage, loc: storage, file_name=name
+                )
+            else:
+                ckpt = torch.load(pretrained, map_location=lambda storage, loc: storage)
+
+            self.load_state_dict(ckpt, strict=True)
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, t, y=None, type_t="noise_level"):
+        r"""
+        Apply the model to an input batch.
+
+        This function takes a noisy image and either a timestep or a noise level as input. Depending on the nature of
+        ``t``, the model returns either a noise map (if ``type_t='timestep'``) or a denoised image (if
+        ``type_t='noise_level'``).
+
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param t: a 1-D batch of timesteps or noise levels.
+        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
+        :param type_t: nature of the embedding ``t``. In traditional diffusion models, and in the authors' code, ``t``
+            is a timestep linked to a noise level; in this case, set ``type_t='timestep'``. We can also choose
+            ``t`` to be a noise level directly and use the model as a denoiser; in this case, set
+            ``type_t='noise_level'``. Default: ``'noise_level'``.
+        :return: an [N x C x ...] Tensor of outputs. Either a noise map (if ``type_t='timestep'``) or a denoised image
+            (if ``type_t='noise_level'``).
+        """
+        if type_t == "timestep":
+            return self.forward_diffusion(x, t, y=y)
+        elif type_t == "noise_level":
+            return self.forward_denoise(x, t, y=y)
+        else:
+            raise ValueError('type_t must be either "timestep" or "noise_level"')
+
+    def forward_diffusion(self, x, timesteps, y=None):
+        r"""
+        Apply the model to an input batch.
+
+        This function takes a noisy image and a timestep (not a noise level) as input and estimates the noise map
+        in the input image.
+        The image is assumed to be in the range [-1, 1] and to have width and height that are powers of two (see the
+        class docstring).
+
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        if self.num_classes is not None:
+            assert y.shape == (x.shape[0],)
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+
+    def get_alpha_prod(
+        self, beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=1000
+    ):
+        """
+        Get the alpha sequences; these are necessary for mapping noise levels to timesteps when performing pure
+        denoising.
+        """
+        betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        betas = torch.from_numpy(betas)  # device placement is left to the caller
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # This is \overline{\alpha}_t
+
+        # Useful sequences deriving from alphas_cumprod
+        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+        sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+        reduced_alpha_cumprod = torch.div(sqrt_1m_alphas_cumprod, sqrt_alphas_cumprod)
+        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
+        return (
+            reduced_alpha_cumprod,
+            sqrt_recip_alphas_cumprod,
+            sqrt_recipm1_alphas_cumprod,
+        )
+
+    def find_nearest(self, array, value):
+        """
+        Return the index of the entry of ``array`` closest to ``value``.
+        """
+        array = np.asarray(array)
+        if isinstance(value, torch.Tensor):
+            value = np.asarray(value.cpu())
+        idx = (np.abs(array - value)).argmin()
+        return idx
+
+    def forward_denoise(self, x, sigma, y=None):
+        r"""
+        Apply the denoising model to an input batch.
+
+        This function takes a noisy image and a noise level (not a timestep) as input and estimates the underlying
+        noiseless image.
+        The input image is assumed to be in the range [0, 1] (up to noise) and to have width and height that are
+        powers of two (see the class docstring).
+
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param sigma: a 1-D batch of noise levels.
+        :param y: an [N] Tensor of labels, if class-conditional. Default=None.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        x = 2.0 * x - 1.0
+        (
+            reduced_alpha_cumprod,
+            sqrt_recip_alphas_cumprod,
+            sqrt_recipm1_alphas_cumprod,
+        ) = self.get_alpha_prod()
+        timesteps = self.find_nearest(
+            reduced_alpha_cumprod, sigma * 2
+        )  # Factor 2 because the image is rescaled to [-1, 1]
+
+        noise_est_sample_var = self.forward_diffusion(
+            x, torch.tensor([timesteps]).to(x.device), y=y
+        )
+        noise_est = noise_est_sample_var[:, :3, ...]  # discard the variance channels
+        denoised = (
+            sqrt_recip_alphas_cumprod[timesteps] * x
+            - sqrt_recipm1_alphas_cumprod[timesteps] * noise_est
+        )
+        denoised = denoised.clamp(-1, 1)
+        return denoised / 2.0 + 0.5
+
+
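As a sanity check on the two entry points above, here is a minimal sketch of using this class as a denoiser, assuming ``DiffUNet`` is exported by ``deepinv.models`` as listed in this package's ``__init__``. ``pretrained=None`` keeps the weights random so the snippet runs offline; pass ``'download'`` for the FFHQ weights. Note that with ``out_channels=3`` the network actually outputs 6 channels (noise estimate plus variance), which ``forward_denoise`` slices internally:

    import torch
    from deepinv.models import DiffUNet

    model = DiffUNet(in_channels=3, out_channels=3, pretrained=None).eval()
    x = torch.rand(1, 3, 64, 64)       # image in [0, 1]; sides are powers of two >= 32
    y = x + 0.1 * torch.randn_like(x)  # noisy observation, sigma = 0.1
    with torch.no_grad():
        # noise-level mode: returns a denoised image in [0, 1]
        xhat = model(y, torch.tensor([0.1]), type_t="noise_level")
        # timestep mode: returns the 6-channel noise/variance prediction
        eps = model(2 * x - 1, torch.tensor([10]), type_t="timestep")
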
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+        )
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            else:
+                x = layer(x)
+        return x
+
+
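A toy sketch of how ``TimestepEmbedSequential`` routes the embedding: plain layers receive ``x`` only, while ``TimestepBlock`` children also receive ``emb``. The ``AddEmb`` block below is hypothetical, defined only for illustration:

    import torch
    import torch.nn as nn

    class AddEmb(TimestepBlock):
        # toy TimestepBlock: broadcast the embedding onto the feature map
        def forward(self, x, emb):
            return x + emb.view(1, -1, 1, 1)

    seq = TimestepEmbedSequential(nn.ReLU(), AddEmb())
    out = seq(torch.rand(1, 4, 8, 8), torch.ones(4))  # ReLU sees x only; AddEmb gets (x, emb)
    assert out.shape == (1, 4, 8, 8)
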
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
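A quick shape check for the two resampling layers (a sketch; the channel count is arbitrary): ``Upsample`` doubles the spatial sides with nearest-neighbour interpolation, while ``Downsample`` halves them with a stride-2 convolution (or average pooling when ``use_conv=False``):

    import torch

    x = torch.rand(1, 8, 16, 16)
    assert Upsample(8, use_conv=True)(x).shape == (1, 8, 32, 32)
    assert Downsample(8, use_conv=True)(x).shape == (1, 8, 8, 8)
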
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
+
+
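With ``use_scale_shift_norm=True`` (the setting used by ``DiffUNet``), the embedding modulates the normalized features FiLM-style, ``h = norm(h) * (1 + scale) + shift``, rather than being added. A minimal sketch, with the channel count chosen to satisfy the 32-group GroupNorm used by ``normalization``:

    import torch

    block = ResBlock(channels=32, emb_channels=16, dropout=0.0, use_scale_shift_norm=True)
    h = block(torch.rand(1, 32, 16, 16), torch.rand(1, 16))
    assert h.shape == (1, 32, 16, 16)  # residual output, channels unchanged
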
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+
+    Originally ported from here, but adapted to the N-d case:
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(self._forward, (x,), self.parameters(), True)
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
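The attention block flattens the spatial grid into a sequence, so an H x W feature map is treated as H*W tokens, and the residual connection keeps the input shape. A small sketch:

    import torch

    attn = AttentionBlock(channels=32, num_head_channels=8)  # 32 / 8 = 4 heads
    x = torch.rand(2, 32, 16, 16)                            # 256 spatial tokens per image
    assert attn(x).shape == x.shape
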
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+
+    Meant to be used like:
+
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial**2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches the legacy QKVAttention + input/output heads shaping.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+ """
809
+ Various utilities for neural networks.
810
+ """
811
+
812
+ import math
813
+
814
+ import torch as th
815
+ import torch.nn as nn
816
+
817
+
818
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
819
+ class SiLU(nn.Module):
820
+ def forward(self, x):
821
+ return x * th.sigmoid(x)
822
+
823
+
824
+ class GroupNorm32(nn.GroupNorm):
825
+ def forward(self, x):
826
+ return super().forward(x.float()).type(x.dtype)
827
+
828
+
829
+ def conv_nd(dims, *args, **kwargs):
830
+ """
831
+ Create a 1D, 2D, or 3D convolution module.
832
+ """
833
+ if dims == 1:
834
+ return nn.Conv1d(*args, **kwargs)
835
+ elif dims == 2:
836
+ return nn.Conv2d(*args, **kwargs)
837
+ elif dims == 3:
838
+ return nn.Conv3d(*args, **kwargs)
839
+ raise ValueError(f"unsupported dimensions: {dims}")
840
+
841
+
842
+ def linear(*args, **kwargs):
843
+ """
844
+ Create a linear module.
845
+ """
846
+ return nn.Linear(*args, **kwargs)
847
+
848
+
849
+ def avg_pool_nd(dims, *args, **kwargs):
850
+ """
851
+ Create a 1D, 2D, or 3D average pooling module.
852
+ """
853
+ if dims == 1:
854
+ return nn.AvgPool1d(*args, **kwargs)
855
+ elif dims == 2:
856
+ return nn.AvgPool2d(*args, **kwargs)
857
+ elif dims == 3:
858
+ return nn.AvgPool3d(*args, **kwargs)
859
+ raise ValueError(f"unsupported dimensions: {dims}")
860
+
861
+
862
+ def update_ema(target_params, source_params, rate=0.99):
863
+ """
864
+ Update target parameters to be closer to those of source parameters using
865
+ an exponential moving average.
866
+
867
+ :param target_params: the target parameter sequence.
868
+ :param source_params: the source parameter sequence.
869
+ :param rate: the EMA rate (closer to 1 means slower).
870
+ """
871
+ for targ, src in zip(target_params, source_params):
872
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
873
+
874
+
875
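``update_ema`` is a one-liner per parameter: ``targ <- rate * targ + (1 - rate) * src``. A sketch of maintaining an EMA copy of a model:

    import copy
    import torch.nn as nn

    model = nn.Linear(4, 4)
    ema_model = copy.deepcopy(model)
    # after each optimizer step:
    update_ema(ema_model.parameters(), model.parameters(), rate=0.99)
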
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+        These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = th.exp(
+        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
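``timestep_embedding`` produces the usual transformer-style sinusoidal features: ``dim // 2`` geometrically spaced frequencies, with cosines in the first half of the vector and sines in the second. A quick check:

    import torch as th

    emb = timestep_embedding(th.tensor([0.0, 10.0, 500.0]), dim=128)
    assert emb.shape == (3, 128)
    assert th.allclose(emb[0, :64], th.ones(64))  # cos(0) = 1 in the cosine half
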
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+        explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        with th.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with th.enable_grad():
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = th.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
+
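Gradient checkpointing trades memory for compute: the forward pass runs under ``no_grad`` and the activations are recomputed during backward. A minimal sketch of calling the helper directly (the lambda and shapes are illustrative only):

    import torch as th
    import torch.nn as nn

    lin = nn.Linear(16, 16)
    x = th.rand(4, 16, requires_grad=True)
    y = checkpoint(lambda t: lin(t).relu(), (x,), tuple(lin.parameters()), True)
    y.sum().backward()  # lin(t).relu() is recomputed here rather than stored
    assert x.grad is not None and lin.weight.grad is not None

Passing ``lin.parameters()`` explicitly is what lets gradients reach parameters that ``func`` closes over but does not take as arguments.
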
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.float()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.float()
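To close the loop on ``forward_denoise``: the noise level is matched to a training timestep through the ``sqrt((1 - abar_t) / abar_t)`` sequence returned by ``get_alpha_prod``. A sketch of that mapping in isolation (instantiating the model only to reach its helper methods):

    model = DiffUNet(pretrained=None)
    reduced, _, _ = model.get_alpha_prod()
    t = model.find_nearest(reduced, 0.1 * 2)  # sigma = 0.1 in [0, 1] units -> 0.2 after [-1, 1] rescaling
    print(t)  # index of the timestep whose training noise level best matches sigma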