monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,480 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from collections.abc import Sequence
|
15
|
+
|
16
|
+
import torch
|
17
|
+
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
19
|
+
|
20
|
+
from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
|
21
|
+
from monai.networks.blocks.spade_norm import SPADE
|
22
|
+
from monai.networks.nets.autoencoderkl import Encoder
|
23
|
+
from monai.utils import ensure_tuple_rep
|
24
|
+
|
25
|
+
# Public star-export of this module: only the autoencoder. SPADEResBlock and SPADEDecoder are defined
# below but are not listed here.
__all__ = ["SPADEAutoencoderKL"]
|
26
|
+
|
27
|
+
|
28
|
+
class SPADEResBlock(nn.Module):
|
29
|
+
"""
|
30
|
+
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
|
31
|
+
residual connection between input and output.
|
32
|
+
Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)
|
33
|
+
|
34
|
+
Args:
|
35
|
+
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
|
36
|
+
in_channels: input channels to the layer.
|
37
|
+
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
|
38
|
+
channels is divisible by this number.
|
39
|
+
norm_eps: epsilon for the normalisation.
|
40
|
+
out_channels: number of output channels.
|
41
|
+
label_nc: number of semantic channels for SPADE normalisation
|
42
|
+
spade_intermediate_channels: number of intermediate channels for SPADE block layer
|
43
|
+
"""
|
44
|
+
|
45
|
+
def __init__(
|
46
|
+
self,
|
47
|
+
spatial_dims: int,
|
48
|
+
in_channels: int,
|
49
|
+
norm_num_groups: int,
|
50
|
+
norm_eps: float,
|
51
|
+
out_channels: int,
|
52
|
+
label_nc: int,
|
53
|
+
spade_intermediate_channels: int,
|
54
|
+
) -> None:
|
55
|
+
super().__init__()
|
56
|
+
self.in_channels = in_channels
|
57
|
+
self.out_channels = in_channels if out_channels is None else out_channels
|
58
|
+
self.norm1 = SPADE(
|
59
|
+
label_nc=label_nc,
|
60
|
+
norm_nc=in_channels,
|
61
|
+
norm="GROUP",
|
62
|
+
norm_params={"num_groups": norm_num_groups, "affine": False},
|
63
|
+
hidden_channels=spade_intermediate_channels,
|
64
|
+
kernel_size=3,
|
65
|
+
spatial_dims=spatial_dims,
|
66
|
+
)
|
67
|
+
self.conv1 = Convolution(
|
68
|
+
spatial_dims=spatial_dims,
|
69
|
+
in_channels=self.in_channels,
|
70
|
+
out_channels=self.out_channels,
|
71
|
+
strides=1,
|
72
|
+
kernel_size=3,
|
73
|
+
padding=1,
|
74
|
+
conv_only=True,
|
75
|
+
)
|
76
|
+
self.norm2 = SPADE(
|
77
|
+
label_nc=label_nc,
|
78
|
+
norm_nc=out_channels,
|
79
|
+
norm="GROUP",
|
80
|
+
norm_params={"num_groups": norm_num_groups, "affine": False},
|
81
|
+
hidden_channels=spade_intermediate_channels,
|
82
|
+
kernel_size=3,
|
83
|
+
spatial_dims=spatial_dims,
|
84
|
+
)
|
85
|
+
self.conv2 = Convolution(
|
86
|
+
spatial_dims=spatial_dims,
|
87
|
+
in_channels=self.out_channels,
|
88
|
+
out_channels=self.out_channels,
|
89
|
+
strides=1,
|
90
|
+
kernel_size=3,
|
91
|
+
padding=1,
|
92
|
+
conv_only=True,
|
93
|
+
)
|
94
|
+
|
95
|
+
self.nin_shortcut: nn.Module
|
96
|
+
if self.in_channels != self.out_channels:
|
97
|
+
self.nin_shortcut = Convolution(
|
98
|
+
spatial_dims=spatial_dims,
|
99
|
+
in_channels=self.in_channels,
|
100
|
+
out_channels=self.out_channels,
|
101
|
+
strides=1,
|
102
|
+
kernel_size=1,
|
103
|
+
padding=0,
|
104
|
+
conv_only=True,
|
105
|
+
)
|
106
|
+
else:
|
107
|
+
self.nin_shortcut = nn.Identity()
|
108
|
+
|
109
|
+
def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
|
110
|
+
h = x
|
111
|
+
h = self.norm1(h, seg)
|
112
|
+
h = F.silu(h)
|
113
|
+
h = self.conv1(h)
|
114
|
+
h = self.norm2(h, seg)
|
115
|
+
h = F.silu(h)
|
116
|
+
h = self.conv2(h)
|
117
|
+
|
118
|
+
x = self.nin_shortcut(x)
|
119
|
+
|
120
|
+
return x + h
|
121
|
+
|
122
|
+
|
123
|
+
class SPADEDecoder(nn.Module):
    """
    Convolutional cascade upsampling from a spatial latent space into an image space.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    The decoder is a flat sequence of blocks:
    an initial 3x3 convolution; an optional non-local stage (SPADE residual block, spatial attention,
    SPADE residual block); then, for each level in reversed ``channels`` order, ``num_res_blocks`` SPADE
    residual blocks, an optional attention block, and (except for the last level) a 2x nearest-neighbour
    upsampling followed by a 3x3 convolution. A GroupNorm and a final 3x3 convolution produce the output.

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        channels: sequence of block output channels.
        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResBlock) per level.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which level from channels contain an attention block.
        label_nc: number of semantic channels for SPADE normalisation.
        with_nonlocal_attn: if truthy, use the non-local attention stage after the initial convolution.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        channels: Sequence[int],
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int],
        norm_num_groups: int,
        norm_eps: float,
        attention_levels: Sequence[bool],
        label_nc: int,
        with_nonlocal_attn: bool = True,
        spade_intermediate_channels: int = 128,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.norm_num_groups = norm_num_groups
        self.norm_eps = norm_eps
        self.attention_levels = attention_levels
        self.label_nc = label_nc

        # The decoder walks the encoder levels from deepest to shallowest.
        reversed_block_out_channels = list(reversed(channels))

        blocks: list[nn.Module] = []

        # Initial convolution: latent channels -> deepest feature channels.
        blocks.append(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=reversed_block_out_channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        # Non-local attention stage. A truthiness test (instead of `is True`) so that values such as
        # numpy.bool_(True) or 1 coming from configuration files do not silently disable it.
        if with_nonlocal_attn:
            blocks.append(
                SPADEResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                    label_nc=label_nc,
                    spade_intermediate_channels=spade_intermediate_channels,
                )
            )
            blocks.append(
                SpatialAttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            blocks.append(
                SPADEResBlock(
                    spatial_dims=spatial_dims,
                    in_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    out_channels=reversed_block_out_channels[0],
                    label_nc=label_nc,
                    spade_intermediate_channels=spade_intermediate_channels,
                )
            )

        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        num_levels = len(reversed_block_out_channels)
        block_out_ch = reversed_block_out_channels[0]
        for i, level_out_ch in enumerate(reversed_block_out_channels):
            block_in_ch = block_out_ch
            block_out_ch = level_out_ch
            is_final_block = i == num_levels - 1

            for _ in range(reversed_num_res_blocks[i]):
                blocks.append(
                    SPADEResBlock(
                        spatial_dims=spatial_dims,
                        in_channels=block_in_ch,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                        out_channels=block_out_ch,
                        label_nc=label_nc,
                        spade_intermediate_channels=spade_intermediate_channels,
                    )
                )
                block_in_ch = block_out_ch

            if reversed_attention_levels[i]:
                blocks.append(
                    SpatialAttentionBlock(
                        spatial_dims=spatial_dims,
                        num_channels=block_in_ch,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
                    )
                )

            # Every level except the shallowest one doubles the spatial resolution.
            if not is_final_block:
                post_conv = Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=block_in_ch,
                    out_channels=block_in_ch,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
                blocks.append(
                    Upsample(
                        spatial_dims=spatial_dims,
                        mode="nontrainable",
                        in_channels=block_in_ch,
                        out_channels=block_in_ch,
                        interp_mode="nearest",
                        scale_factor=2.0,
                        post_conv=post_conv,
                        align_corners=None,
                    )
                )

        # Output head: normalisation followed by projection to the image channels.
        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
        blocks.append(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=block_in_ch,
                out_channels=out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: latent tensor.
            seg: semantic map; only SPADEResBlock instances receive it, every other block is called with ``x``.

        Returns:
            decoded tensor.
        """
        for block in self.blocks:
            if isinstance(block, SPADEResBlock):
                x = block(x, seg)
            else:
                x = block(x)
        return x
|
294
|
+
|
295
|
+
|
296
|
+
class SPADEAutoencoderKL(nn.Module):
|
297
|
+
"""
|
298
|
+
Autoencoder model with KL-regularized latent space based on
|
299
|
+
Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
|
300
|
+
and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
|
301
|
+
Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)
|
302
|
+
|
303
|
+
Args:
|
304
|
+
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
|
305
|
+
label_nc: number of semantic channels for SPADE normalisation.
|
306
|
+
in_channels: number of input channels.
|
307
|
+
out_channels: number of output channels.
|
308
|
+
num_res_blocks: number of residual blocks (see ResBlock) per level.
|
309
|
+
channels: sequence of block output channels.
|
310
|
+
attention_levels: sequence of levels to add attention.
|
311
|
+
latent_channels: latent embedding dimension.
|
312
|
+
norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
|
313
|
+
norm_eps: epsilon for the normalization.
|
314
|
+
with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
|
315
|
+
with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
|
316
|
+
spade_intermediate_channels: number of intermediate channels for SPADE block layer.
|
317
|
+
"""
|
318
|
+
|
319
|
+
def __init__(
|
320
|
+
self,
|
321
|
+
spatial_dims: int,
|
322
|
+
label_nc: int,
|
323
|
+
in_channels: int = 1,
|
324
|
+
out_channels: int = 1,
|
325
|
+
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
|
326
|
+
channels: Sequence[int] = (32, 64, 64, 64),
|
327
|
+
attention_levels: Sequence[bool] = (False, False, True, True),
|
328
|
+
latent_channels: int = 3,
|
329
|
+
norm_num_groups: int = 32,
|
330
|
+
norm_eps: float = 1e-6,
|
331
|
+
with_encoder_nonlocal_attn: bool = True,
|
332
|
+
with_decoder_nonlocal_attn: bool = True,
|
333
|
+
spade_intermediate_channels: int = 128,
|
334
|
+
) -> None:
|
335
|
+
super().__init__()
|
336
|
+
|
337
|
+
# All number of channels should be multiple of num_groups
|
338
|
+
if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
|
339
|
+
raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups")
|
340
|
+
|
341
|
+
if len(channels) != len(attention_levels):
|
342
|
+
raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels")
|
343
|
+
|
344
|
+
if isinstance(num_res_blocks, int):
|
345
|
+
num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
|
346
|
+
|
347
|
+
if len(num_res_blocks) != len(channels):
|
348
|
+
raise ValueError(
|
349
|
+
"`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
|
350
|
+
"`channels`."
|
351
|
+
)
|
352
|
+
|
353
|
+
self.encoder = Encoder(
|
354
|
+
spatial_dims=spatial_dims,
|
355
|
+
in_channels=in_channels,
|
356
|
+
channels=channels,
|
357
|
+
out_channels=latent_channels,
|
358
|
+
num_res_blocks=num_res_blocks,
|
359
|
+
norm_num_groups=norm_num_groups,
|
360
|
+
norm_eps=norm_eps,
|
361
|
+
attention_levels=attention_levels,
|
362
|
+
with_nonlocal_attn=with_encoder_nonlocal_attn,
|
363
|
+
)
|
364
|
+
self.decoder = SPADEDecoder(
|
365
|
+
spatial_dims=spatial_dims,
|
366
|
+
channels=channels,
|
367
|
+
in_channels=latent_channels,
|
368
|
+
out_channels=out_channels,
|
369
|
+
num_res_blocks=num_res_blocks,
|
370
|
+
norm_num_groups=norm_num_groups,
|
371
|
+
norm_eps=norm_eps,
|
372
|
+
attention_levels=attention_levels,
|
373
|
+
label_nc=label_nc,
|
374
|
+
with_nonlocal_attn=with_decoder_nonlocal_attn,
|
375
|
+
spade_intermediate_channels=spade_intermediate_channels,
|
376
|
+
)
|
377
|
+
self.quant_conv_mu = Convolution(
|
378
|
+
spatial_dims=spatial_dims,
|
379
|
+
in_channels=latent_channels,
|
380
|
+
out_channels=latent_channels,
|
381
|
+
strides=1,
|
382
|
+
kernel_size=1,
|
383
|
+
padding=0,
|
384
|
+
conv_only=True,
|
385
|
+
)
|
386
|
+
self.quant_conv_log_sigma = Convolution(
|
387
|
+
spatial_dims=spatial_dims,
|
388
|
+
in_channels=latent_channels,
|
389
|
+
out_channels=latent_channels,
|
390
|
+
strides=1,
|
391
|
+
kernel_size=1,
|
392
|
+
padding=0,
|
393
|
+
conv_only=True,
|
394
|
+
)
|
395
|
+
self.post_quant_conv = Convolution(
|
396
|
+
spatial_dims=spatial_dims,
|
397
|
+
in_channels=latent_channels,
|
398
|
+
out_channels=latent_channels,
|
399
|
+
strides=1,
|
400
|
+
kernel_size=1,
|
401
|
+
padding=0,
|
402
|
+
conv_only=True,
|
403
|
+
)
|
404
|
+
self.latent_channels = latent_channels
|
405
|
+
|
406
|
+
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
407
|
+
"""
|
408
|
+
Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
|
409
|
+
|
410
|
+
Args:
|
411
|
+
x: BxCx[SPATIAL DIMS] tensor
|
412
|
+
|
413
|
+
"""
|
414
|
+
h = self.encoder(x)
|
415
|
+
z_mu = self.quant_conv_mu(h)
|
416
|
+
z_log_var = self.quant_conv_log_sigma(h)
|
417
|
+
z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
|
418
|
+
z_sigma = torch.exp(z_log_var / 2)
|
419
|
+
|
420
|
+
return z_mu, z_sigma
|
421
|
+
|
422
|
+
def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
|
423
|
+
"""
|
424
|
+
From the mean and sigma representations resulting of encoding an image through the latent space,
|
425
|
+
obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
|
426
|
+
adding the mean.
|
427
|
+
|
428
|
+
Args:
|
429
|
+
z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
|
430
|
+
z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
|
431
|
+
|
432
|
+
Returns:
|
433
|
+
sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
|
434
|
+
"""
|
435
|
+
eps = torch.randn_like(z_sigma)
|
436
|
+
z_vae = z_mu + eps * z_sigma
|
437
|
+
return z_vae
|
438
|
+
|
439
|
+
def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
|
440
|
+
"""
|
441
|
+
Encodes and decodes an input image.
|
442
|
+
|
443
|
+
Args:
|
444
|
+
x: BxCx[SPATIAL DIMENSIONS] tensor.
|
445
|
+
seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
|
446
|
+
Returns:
|
447
|
+
reconstructed image, of the same shape as input
|
448
|
+
"""
|
449
|
+
z_mu, _ = self.encode(x)
|
450
|
+
reconstruction = self.decode(z_mu, seg)
|
451
|
+
return reconstruction
|
452
|
+
|
453
|
+
def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
|
454
|
+
"""
|
455
|
+
Based on a latent space sample, forwards it through the Decoder.
|
456
|
+
|
457
|
+
Args:
|
458
|
+
z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
|
459
|
+
seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
|
460
|
+
Returns:
|
461
|
+
decoded image tensor
|
462
|
+
"""
|
463
|
+
z = self.post_quant_conv(z)
|
464
|
+
dec: torch.Tensor = self.decoder(z, seg)
|
465
|
+
return dec
|
466
|
+
|
467
|
+
def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
468
|
+
z_mu, z_sigma = self.encode(x)
|
469
|
+
z = self.sampling(z_mu, z_sigma)
|
470
|
+
reconstruction = self.decode(z, seg)
|
471
|
+
return reconstruction, z_mu, z_sigma
|
472
|
+
|
473
|
+
def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
|
474
|
+
z_mu, z_sigma = self.encode(x)
|
475
|
+
z = self.sampling(z_mu, z_sigma)
|
476
|
+
return z
|
477
|
+
|
478
|
+
def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
|
479
|
+
image = self.decode(z, seg)
|
480
|
+
return image
|