monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,702 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from collections.abc import Sequence
|
15
|
+
from typing import List
|
16
|
+
|
17
|
+
import torch
|
18
|
+
import torch.nn as nn
|
19
|
+
import torch.nn.functional as F
|
20
|
+
|
21
|
+
from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
|
22
|
+
from monai.utils import ensure_tuple_rep, optional_import
|
23
|
+
|
24
|
+
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
|
25
|
+
|
26
|
+
__all__ = ["AutoencoderKL"]
|
27
|
+
|
28
|
+
|
29
|
+
class AsymmetricPad(nn.Module):
    """
    Zero-pad the last ``spatial_dims`` axes of a tensor by one element on their trailing side only.

    No padding is added at the start of any axis, so each padded axis grows by exactly one element.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
    """

    def __init__(self, spatial_dims: int) -> None:
        super().__init__()
        # ``pad`` expects (before, after) pairs starting from the last axis; every padded axis gets (0, 1).
        self.pad = tuple(amount for _ in range(spatial_dims) for amount in (0, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        padded = nn.functional.pad(x, self.pad, mode="constant", value=0.0)
        return padded
|
44
|
+
|
45
|
+
|
46
|
+
class AEKLDownsample(nn.Module):
    """
    Downsampling layer: trailing-edge zero padding followed by a stride-2, kernel-3 convolution.

    The convolution itself uses no padding; the one-element trailing pad from ``AsymmetricPad`` is
    applied beforehand instead. The number of channels is left unchanged.

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        in_channels: number of input channels (also the number of output channels).
    """

    def __init__(self, spatial_dims: int, in_channels: int) -> None:
        super().__init__()
        # registered before the convolution so module/state-dict ordering matches the original layout
        self.pad = AsymmetricPad(spatial_dims=spatial_dims)

        self.conv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,  # channel count is preserved
            strides=2,
            kernel_size=3,
            padding=0,  # padding is provided asymmetrically by ``self.pad``
            conv_only=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.pad(x))
|
73
|
+
|
74
|
+
|
75
|
+
class AEKLResBlock(nn.Module):
    """
    Residual block: two (GroupNorm -> SiLU -> 3x3 convolution) stages plus a skip connection.

    The skip path is a 1x1 convolution when the input and output channel counts differ, and the
    identity otherwise. The block returns ``skip(x) + main(x)``.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: input channels to the layer.
        norm_num_groups: number of groups involved for the group normalisation layers. Both ``in_channels``
            and the resolved output channel count must be divisible by this number.
        norm_eps: epsilon for the normalisation.
        out_channels: number of output channels. ``None`` keeps the number of input channels.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int | None
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.in_channels, eps=norm_eps, affine=True)
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )
        # use the resolved channel count: the raw ``out_channels`` argument may be None
        self.norm2 = nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True
        )
        self.conv2 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        self.nin_shortcut: nn.Module
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip path matches the main path's channel count
            self.nin_shortcut = Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = F.silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        x = self.nin_shortcut(x)

        return x + h
|
144
|
+
|
145
|
+
|
146
|
+
class Encoder(nn.Module):
    """
    Convolutional cascade that downsamples the image into a spatial latent space.

    The layers, in order, are:

    - an initial 3x3 convolution from ``in_channels`` to ``channels[0]``;
    - for each level ``i``: ``num_res_blocks[i]`` residual blocks, each followed by a spatial attention block
      when ``attention_levels[i]`` is True, then a downsampling layer unless ``i`` is the last level;
    - when ``with_nonlocal_attn`` is truthy: residual block -> spatial attention block -> residual block
      at ``channels[-1]``;
    - GroupNorm followed by a 3x3 convolution to ``out_channels``.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: number of input channels.
        channels: sequence of block output channels.
        out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        num_res_blocks: number of residual blocks (see AEKLResBlock) per level.
        norm_num_groups: number of groups for the GroupNorm layers, every entry of ``channels`` must be
            divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which levels of ``channels`` contain an attention block.
        with_nonlocal_attn: if True use non-local attention block.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        channels: Sequence[int],
        out_channels: int,
        num_res_blocks: Sequence[int],
        norm_num_groups: int,
        norm_eps: float,
        attention_levels: Sequence[bool],
        with_nonlocal_attn: bool = True,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.channels = channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.norm_num_groups = norm_num_groups
        self.norm_eps = norm_eps
        self.attention_levels = attention_levels

        # NOTE: the order of appends defines the ``blocks.<index>`` state-dict keys; do not reorder.
        blocks: List[nn.Module] = []
        # Initial convolution
        blocks.append(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        # Residual and downsampling blocks
        output_channel = channels[0]
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_final_block = i == len(channels) - 1

            for _ in range(self.num_res_blocks[i]):
                blocks.append(
                    AEKLResBlock(
                        spatial_dims=spatial_dims,
                        in_channels=input_channel,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                        out_channels=output_channel,
                    )
                )
                input_channel = output_channel
                if attention_levels[i]:
                    blocks.append(
                        SpatialAttentionBlock(
                            spatial_dims=spatial_dims,
                            num_channels=input_channel,
                            norm_num_groups=norm_num_groups,
                            norm_eps=norm_eps,
                        )
                    )

            if not is_final_block:
                blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel))

        # Non-local attention block. A truthiness test (not ``is True``) so flags such as numpy.bool_(True)
        # or 1 do not silently drop these layers and change the state-dict layout.
        if with_nonlocal_attn:
            blocks.append(
                AEKLResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=channels[-1],
                )
            )

            blocks.append(
                SpatialAttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            blocks.append(
                AEKLResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=channels[-1],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=channels[-1],
                )
            )
        # Normalise and convert to latent size
        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True))
        blocks.append(
            Convolution(
                spatial_dims=self.spatial_dims,
                in_channels=channels[-1],
                out_channels=out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x
|
277
|
+
|
278
|
+
|
279
|
+
class Decoder(nn.Module):
    """
    Convolutional cascade upsampling from a spatial latent space into an image space.

    ``channels``, ``num_res_blocks`` and ``attention_levels`` are given in encoder order and traversed in reverse.
    The layers, in order, are:

    - an initial 3x3 convolution from ``in_channels`` to ``channels[-1]``;
    - when ``with_nonlocal_attn`` is truthy: residual block -> spatial attention block -> residual block;
    - for each reversed level: residual blocks, each followed by a spatial attention block when enabled for that
      level, then a 2x upsampling layer unless it is the last level;
    - GroupNorm followed by a 3x3 convolution to ``out_channels``.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        channels: sequence of block output channels (encoder order).
        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see AEKLResBlock) per level (encoder order).
        norm_num_groups: number of groups for the GroupNorm layers, every entry of ``channels`` must be
            divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which levels of ``channels`` contain an attention block (encoder order).
        with_nonlocal_attn: if True use non-local attention block.
        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder; otherwise use
            nearest-neighbour interpolation followed by a 3x3 convolution.
    """

    def __init__(
        self,
        spatial_dims: int,
        channels: Sequence[int],
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int],
        norm_num_groups: int,
        norm_eps: float,
        attention_levels: Sequence[bool],
        with_nonlocal_attn: bool = True,
        use_convtranspose: bool = False,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.norm_num_groups = norm_num_groups
        self.norm_eps = norm_eps
        self.attention_levels = attention_levels

        reversed_block_out_channels = list(reversed(channels))

        # NOTE: the order of appends defines the ``blocks.<index>`` state-dict keys; do not reorder.
        blocks: List[nn.Module] = []

        # Initial convolution
        blocks.append(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=reversed_block_out_channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        # Non-local attention block. A truthiness test (not ``is True``) so flags such as numpy.bool_(True)
        # or 1 do not silently drop these layers and change the state-dict layout.
        if with_nonlocal_attn:
            blocks.append(
                AEKLResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                )
            )
            blocks.append(
                SpatialAttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            blocks.append(
                AEKLResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                )
            )

        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        block_out_ch = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            block_in_ch = block_out_ch
            block_out_ch = reversed_block_out_channels[i]
            is_final_block = i == len(reversed_block_out_channels) - 1

            for _ in range(reversed_num_res_blocks[i]):
                blocks.append(
                    AEKLResBlock(
                        spatial_dims=spatial_dims,
                        in_channels=block_in_ch,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                        out_channels=block_out_ch,
                    )
                )
                block_in_ch = block_out_ch

                if reversed_attention_levels[i]:
                    blocks.append(
                        SpatialAttentionBlock(
                            spatial_dims=spatial_dims,
                            num_channels=block_in_ch,
                            norm_num_groups=norm_num_groups,
                            norm_eps=norm_eps,
                        )
                    )

            if not is_final_block:
                if use_convtranspose:
                    blocks.append(
                        Upsample(
                            spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch
                        )
                    )
                else:
                    # the convolution is applied by Upsample after the nearest-neighbour interpolation
                    post_conv = Convolution(
                        spatial_dims=spatial_dims,
                        in_channels=block_in_ch,
                        out_channels=block_in_ch,
                        strides=1,
                        kernel_size=3,
                        padding=1,
                        conv_only=True,
                    )
                    blocks.append(
                        Upsample(
                            spatial_dims=spatial_dims,
                            mode="nontrainable",
                            in_channels=block_in_ch,
                            out_channels=block_in_ch,
                            interp_mode="nearest",
                            scale_factor=2.0,
                            post_conv=post_conv,
                            align_corners=None,
                        )
                    )

        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
        blocks.append(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=block_in_ch,
                out_channels=out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x
|
444
|
+
|
445
|
+
|
446
|
+
class AutoencoderKL(nn.Module):
    """
    Autoencoder model with KL-regularized latent space based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see AEKLResBlock) per level. A single integer is repeated
            for every level.
        channels: number of output channels for each block.
        attention_levels: sequence of levels to add attention; must have the same length as ``channels``.
        latent_channels: latent embedding dimension.
        norm_num_groups: number of groups for the GroupNorm layers, every entry of ``channels`` must be
            divisible by this number.
        norm_eps: epsilon for the normalization.
        with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
        with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
        use_checkpoint: if True, use activation checkpoint to save memory.
        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.

    Raises:
        ValueError: when an entry of ``channels`` is not divisible by ``norm_num_groups``, when
            ``attention_levels`` and ``channels`` differ in length, or when ``num_res_blocks`` is a sequence
            whose length differs from ``channels``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int = 1,
        out_channels: int = 1,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        latent_channels: int = 3,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        with_encoder_nonlocal_attn: bool = True,
        with_decoder_nonlocal_attn: bool = True,
        use_checkpoint: bool = False,
        use_convtranspose: bool = False,
    ) -> None:
        super().__init__()

        # All number of channels should be multiple of num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
            raise ValueError("AutoencoderKL expects all channels being multiple of norm_num_groups")

        if len(channels) != len(attention_levels):
            raise ValueError("AutoencoderKL expects channels being same size of attention_levels")

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))

        if len(num_res_blocks) != len(channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`channels`."
            )

        # NOTE: submodule registration order below determines state-dict key order; keep it unchanged.
        self.encoder = Encoder(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            channels=channels,
            out_channels=latent_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_encoder_nonlocal_attn,
        )
        self.decoder = Decoder(
            spatial_dims=spatial_dims,
            channels=channels,
            in_channels=latent_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_decoder_nonlocal_attn,
            use_convtranspose=use_convtranspose,
        )
        self.quant_conv_mu = Convolution(
            spatial_dims=spatial_dims,
            in_channels=latent_channels,
            out_channels=latent_channels,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )
        # despite the name, ``encode`` treats this layer's output as a log-variance
        self.quant_conv_log_sigma = Convolution(
            spatial_dims=spatial_dims,
            in_channels=latent_channels,
            out_channels=latent_channels,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )
        self.post_quant_conv = Convolution(
            spatial_dims=spatial_dims,
            in_channels=latent_channels,
            out_channels=latent_channels,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )
        self.latent_channels = latent_channels
        self.use_checkpoint = use_checkpoint

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.

        Args:
            x: BxCx[SPATIAL DIMS] tensor

        Returns:
            ``(z_mu, z_sigma)``, where ``z_sigma`` is the standard deviation
            ``exp(clamp(log_var, -30, 20) / 2)``.
        """
        if self.use_checkpoint:
            h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
        else:
            h = self.encoder(x)

        z_mu = self.quant_conv_mu(h)
        z_log_var = self.quant_conv_log_sigma(h)
        # clamp before exponentiating to keep sigma finite and non-zero
        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
        z_sigma = torch.exp(z_log_var / 2)

        return z_mu, z_sigma

    def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
        """
        Draw a latent sample with the reparameterisation trick.

        Samples standard gaussian noise, multiplies it by the standard deviation (sigma) and adds the mean.

        Args:
            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you
                encode an image

        Returns:
            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
        """
        eps = torch.randn_like(z_sigma)
        z_vae = z_mu + eps * z_sigma
        return z_vae

    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encodes and decodes an input image deterministically, using the latent mean without sampling.

        Args:
            x: BxCx[SPATIAL DIMENSIONS] tensor.

        Returns:
            reconstructed image
        """
        z_mu, _ = self.encode(x)
        reconstruction = self.decode(z_mu)
        return reconstruction

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Based on a latent space sample, forwards it through the Decoder.

        Args:
            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]

        Returns:
            decoded image tensor
        """
        z = self.post_quant_conv(z)
        dec: torch.Tensor
        if self.use_checkpoint:
            dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
        else:
            dec = self.decoder(z)
        return dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode, sample from the latent distribution, and decode.

        Returns:
            ``(reconstruction, z_mu, z_sigma)``.
        """
        z_mu, z_sigma = self.encode(x)
        z = self.sampling(z_mu, z_sigma)
        reconstruction = self.decode(z)
        return reconstruction, z_mu, z_sigma

    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return a sampled latent (used as input for a second-stage model)."""
        z_mu, z_sigma = self.encode(x)
        z = self.sampling(z_mu, z_sigma)
        return z

    def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent produced by a second-stage model back into image space."""
        image = self.decode(z)
        return image

    def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
        """
        Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        If every key of ``old_state_dict`` exists in this model, it is loaded directly. Otherwise matching keys
        are copied over and the legacy layouts are converted:

        - the separate ``to_q``/``to_k``/``to_v`` projections of each attention block are concatenated along
          dim 0 into ``attn.qkv``; the output projection, absent in the old models, is set to an identity weight
          and a zero bias;
        - upsampling convolutions stored under ``conv`` are copied into the renamed ``postconv`` parameters.

        Args:
            old_state_dict: state dict from the old AutoencoderKL model.
            verbose: if True, print the keys that exist in only one of the two state dicts.
        """

        new_state_dict = self.state_dict()
        # if all keys match, just load the state dict
        if all(k in new_state_dict for k in old_state_dict):
            print("All keys match, loading state dict.")
            self.load_state_dict(old_state_dict)
            return

        if verbose:
            # print all new_state_dict keys that are not in old_state_dict
            for k in new_state_dict:
                if k not in old_state_dict:
                    print(f"key {k} not found in old state dict")
            # and vice versa
            print("----------------------------------------------")
            for k in old_state_dict:
                if k not in new_state_dict:
                    print(f"key {k} not found in new state dict")

        # copy over all matching keys
        for k in new_state_dict:
            if k in old_state_dict:
                new_state_dict[k] = old_state_dict[k]

        # fix the attention blocks
        attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
        for block in attention_blocks:
            new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
                [
                    old_state_dict[f"{block}.to_q.weight"],
                    old_state_dict[f"{block}.to_k.weight"],
                    old_state_dict[f"{block}.to_v.weight"],
                ],
                dim=0,
            )
            new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat(
                [
                    old_state_dict[f"{block}.to_q.bias"],
                    old_state_dict[f"{block}.to_k.bias"],
                    old_state_dict[f"{block}.to_v.bias"],
                ],
                dim=0,
            )
            # old version did not have a projection so set these to the identity
            new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
                new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
            )
            new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
                new_state_dict[f"{block}.attn.out_proj.bias"].shape
            )

        # fix the upsample conv blocks which were renamed postconv
        for k in new_state_dict:
            if "postconv" in k:
                old_name = k.replace("postconv", "conv")
                new_state_dict[k] = old_state_dict[old_name]
        self.load_state_dict(new_state_dict)
|