careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +39 -17
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
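The bulk of this release is the rework of the LVAE model shown below: LadderVAE.__init__ now takes explicit encoder/decoder convolution strides and a tuple-valued input_shape, which drive the new 2D/3D handling. As a quick orientation before reading the diff, here is a minimal instantiation sketch based only on the new signature visible in the diff; the argument values are illustrative and are not careamics defaults.

    # Sketch only: the signature comes from the new LadderVAE.__init__ in the diff
    # below; the concrete values are illustrative, not careamics defaults.
    from careamics.models.lvae.lvae import LadderVAE

    model = LadderVAE(
        input_shape=(64, 64),          # (Y, X) for 2D data; a (Z, Y, X) tuple selects the 3D path
        output_channels=2,
        multiscale_count=1,            # 1 disables the lateral-context (LC) inputs
        z_dims=[128, 128, 128, 128],   # one latent size per ladder level
        encoder_n_filters=64,
        decoder_n_filters=64,
        encoder_conv_strides=[2, 2],   # 2 entries -> nn.Conv2d, 3 entries -> nn.Conv3d
        decoder_conv_strides=[2, 2],   # 3D encoder + 2D decoder triggers the new "squish" path
        encoder_dropout=0.1,
        decoder_dropout=0.1,
        nonlinearity="ELU",            # passed to get_activation(); accepted names are not shown in this diff
        predict_logvar=None,           # the code only checks "is not None" to size the output layer
        analytical_kl=False,
    )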
careamics/models/lvae/lvae.py
CHANGED
@@ -1,11 +1,12 @@
 """
-Ladder VAE (LVAE) Model
+Ladder VAE (LVAE) Model.
 
-The current implementation is based on "Interpretable Unsupervised Diversity Denoising
+The current implementation is based on "Interpretable Unsupervised Diversity Denoising
+and Artefact Removal, Prakash et al."
 """
 
 from collections.abc import Iterable
-from typing import
+from typing import Union
 
 import numpy as np
 import torch
@@ -17,42 +18,80 @@ from ..activation import get_activation
 from .layers import (
     BottomUpDeterministicResBlock,
     BottomUpLayer,
+    GateLayer,
     TopDownDeterministicResBlock,
     TopDownLayer,
 )
-from .utils import Interpolate, ModelType, crop_img_tensor
+from .utils import Interpolate, ModelType, crop_img_tensor
 
 
 @register_model("LVAE")
 class LadderVAE(nn.Module):
+    """
+    Constructor.
+
+    Parameters
+    ----------
+    input_shape : int
+        The size of the input image.
+    output_channels : int
+        The number of output channels.
+    multiscale_count : int
+        The number of scales for multiscale processing.
+    z_dims : list[int]
+        The dimensions of the latent space for each layer.
+    encoder_n_filters : int
+        The number of filters in the encoder.
+    decoder_n_filters : int
+        The number of filters in the decoder.
+    encoder_conv_strides : list[int]
+        The strides for the conv layers encoder.
+    decoder_conv_strides : list[int]
+        The strides for the conv layers decoder.
+    encoder_dropout : float
+        The dropout rate for the encoder.
+    decoder_dropout : float
+        The dropout rate for the decoder.
+    nonlinearity : str
+        The nonlinearity function to use.
+    predict_logvar : bool
+        Whether to predict the log variance.
+    analytical_kl : bool
+        Whether to use analytical KL divergence.
+
+    Raises
+    ------
+    NotImplementedError
+        If only 2D convolutions are supported.
+    """
 
     def __init__(
         self,
        input_shape: int,
         output_channels: int,
         multiscale_count: int,
-        z_dims:
+        z_dims: list[int],
         encoder_n_filters: int,
         decoder_n_filters: int,
+        encoder_conv_strides: list[int],
+        decoder_conv_strides: list[int],
        encoder_dropout: float,
         decoder_dropout: float,
         nonlinearity: str,
         predict_logvar: bool,
         analytical_kl: bool,
     ):
-        """
-        Constructor.
-
-        Parameters
-        ----------
-
-        """
         super().__init__()
 
         # -------------------------------------------------------
         # Customizable attributes
         self.image_size = input_shape
+        """Input image size. (Z, Y, X) or (Y, X) if the data is 2D."""
+        # TODO: we need to be careful with this since used to be an int.
+        # the tuple of shapes used to be `self.input_shape`.
        self.target_ch = output_channels
+        self.encoder_conv_strides = encoder_conv_strides
+        self.decoder_conv_strides = decoder_conv_strides
         self._multiscale_count = multiscale_count
         self.z_dims = z_dims
         self.encoder_n_filters = encoder_n_filters
@@ -80,7 +119,6 @@ class LadderVAE(nn.Module):
         self.merge_type = "residual"
         self.no_initial_downscaling = True
         self.skip_bottomk_buvalues = 0
-        self.non_stochastic_version = False
         self.stochastic_skip = True
         self.learn_top_prior = True
         self.res_block_type = "bacdbacd" # TODO remove !
@@ -91,9 +129,7 @@ class LadderVAE(nn.Module):
         self._enable_topdown_normalize_factor = True
 
         # Attributes that handle LC -> Hardcoded
-        self.enable_multiscale =
-            self._multiscale_count is not None and self._multiscale_count > 1
-        )
+        self.enable_multiscale = self._multiscale_count > 1
         self.multiscale_retain_spatial_dims = True
         self.multiscale_lowres_separate_branch = False
         self.multiscale_decoder_retain_spatial_dims = (
@@ -102,14 +138,6 @@ class LadderVAE(nn.Module):
 
         # Derived attributes
         self.n_layers = len(self.z_dims)
-        self.encoder_no_padding_mode = (
-            self.encoder_res_block_skip_padding is True
-            and self.encoder_res_block_kernel > 1
-        )
-        self.decoder_no_padding_mode = (
-            self.decoder_res_block_skip_padding is True
-            and self.decoder_res_block_kernel > 1
-        )
 
         # Others...
         self._tethered_to_input = False
@@ -127,19 +155,41 @@ class LadderVAE(nn.Module):
 
         # -------------------------------------------------------
         # Data attributes
-        self.color_ch = 1
-        self.img_shape = (self.image_size, self.image_size)
+        self.color_ch = 1 # TODO for now we only support 1 channel
         self.normalized_input = True
         # -------------------------------------------------------
 
         # -------------------------------------------------------
         # Loss attributes
-        self._restricted_kl = False # HC
         # enabling reconstruction loss on mixed input
         self.mixed_rec_w = 0
         self.nbr_consistency_w = 0
 
         # -------------------------------------------------------
+        # 3D related stuff
+        self._mode_3D = len(self.image_size) == 3 # TODO refac
+        self._model_3D_depth = self.image_size[0] if self._mode_3D else 1
+        self._decoder_mode_3D = len(self.decoder_conv_strides) == 3
+        if self._mode_3D and not self._decoder_mode_3D:
+            assert self._model_3D_depth % 2 == 1, "3D model depth should be odd"
+        assert (
+            self._mode_3D is True or self._decoder_mode_3D is False
+        ), "Decoder cannot be 3D when encoder is 2D"
+        self._squish3d = self._mode_3D and not self._decoder_mode_3D
+        self._3D_squisher = (
+            None
+            if not self._squish3d
+            else nn.ModuleList(
+                [
+                    GateLayer(
+                        channels=self.encoder_n_filters,
+                        conv_strides=self.encoder_conv_strides,
+                    )
+                    for k in range(len(self.z_dims))
+                ]
+            )
+        )
+        # TODO: this bit is in the Ashesh's confusing-hacky style... Can we do better?
 
         # -------------------------------------------------------
         # # Training attributes
@@ -168,6 +218,11 @@ class LadderVAE(nn.Module):
         ### CREATE MODEL BLOCKS
         # First bottom-up layer: change num channels + downsample by factor 2
         # unless we want to prevent this
+        self.encoder_conv_op = getattr(nn, f"Conv{len(self.encoder_conv_strides)}d")
+        # TODO these should be defined for all layers here ?
+        self.decoder_conv_op = getattr(nn, f"Conv{len(self.decoder_conv_strides)}d")
+        # TODO: would be more readable to have a derived parameters to use like
+        # `conv_dims = len(self.encoder_conv_strides)` and then use `Conv{conv_dims}d`
         stride = 1 if self.no_initial_downscaling else 2
         self.first_bottom_up = self.create_first_bottom_up(stride)
 
@@ -191,7 +246,7 @@ class LadderVAE(nn.Module):
 
         # Output layer --> Project to target_ch many channels
         logvar_ch_needed = self.predict_logvar is not None
-        self.output_layer = self.parameter_net =
+        self.output_layer = self.parameter_net = self.decoder_conv_op(
             self.decoder_n_filters,
             self.target_ch * (1 + logvar_ch_needed),
             kernel_size=3,
@@ -205,6 +260,7 @@ class LadderVAE(nn.Module):
         # PSNR computation on validation.
         # self.label1_psnr = RunningPSNR()
         # self.label2_psnr = RunningPSNR()
+        # TODO: did you add this?
 
         # msg =f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}'
         # msg += f' TargetCh: {self.target_ch}'
@@ -217,7 +273,8 @@ class LadderVAE(nn.Module):
         num_res_blocks: int = 1,
     ) -> nn.Sequential:
         """
-
+        Method creates the first bottom-up block of the Encoder.
+
         Its role is to perform a first image compression step.
         It is composed by a sequence of nn.Conv2d + non-linearity +
         BottomUpDeterministicResBlock (1 or more, default is 1).
@@ -225,29 +282,30 @@ class LadderVAE(nn.Module):
         Parameters
         ----------
         init_stride: int
-            The stride used by the
+            The stride used by the intial Conv2d block.
         num_res_blocks: int, optional
-            The number of BottomUpDeterministicResBlocks
+            The number of BottomUpDeterministicResBlocks, default is 1.
         """
+        # From what I got from Ashesh, Z should not be touched in any case.
         nonlin = get_activation(self.nonlin)
-
-
-
-
-
-
-
-
-                else self.encoder_res_block_kernel // 2
-            ),
-            stride=init_stride,
+        conv_block = self.encoder_conv_op(
+            in_channels=self.color_ch,
+            out_channels=self.encoder_n_filters,
+            kernel_size=self.encoder_res_block_kernel,
+            padding=(
+                0
+                if self.encoder_res_block_skip_padding
+                else self.encoder_res_block_kernel // 2
             ),
-
-
+            stride=init_stride,
+        )
+
+        modules = [conv_block, nonlin]
 
         for _ in range(num_res_blocks):
             modules.append(
                 BottomUpDeterministicResBlock(
+                    conv_strides=self.encoder_conv_strides,
                     c_in=self.encoder_n_filters,
                     c_out=self.encoder_n_filters,
                     nonlin=nonlin,
@@ -255,7 +313,6 @@ class LadderVAE(nn.Module):
                     batchnorm=self.bottomup_batchnorm,
                     dropout=self.encoder_dropout,
                     res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                     res_block_kernel=self.encoder_res_block_kernel,
                 )
             )
@@ -264,7 +321,8 @@ class LadderVAE(nn.Module):
 
     def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList:
         """
-
+        Method creates the stack of bottom-up layers of the Encoder.
+
         that are used to generate the so-called `bu_values`.
 
         NOTE:
@@ -274,8 +332,9 @@ class LadderVAE(nn.Module):
         Parameters
         ----------
         lowres_separate_branch: bool
-            Whether the residual block(s) used for encoding the low-res input are shared
-            not (`True`) with the "same-size" residual block(s) in the
+            Whether the residual block(s) used for encoding the low-res input are shared
+            (`False`) or not (`True`) with the "same-size" residual block(s) in the
+            `BottomUpLayer`'s primary flow.
         """
         multiscale_lowres_size_factor = 1
         nonlin = get_activation(self.nonlin)
@@ -294,11 +353,11 @@ class LadderVAE(nn.Module):
             # N.B. Only used if layer_enable_multiscale == True, so we updated it only in that case
             multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale)
 
-
-
-
-
-
+            # TODO: check correctness of this
+            if self._multiscale_count > 1:
+                output_expected_shape = (dim // 2 ** (i + 1) for dim in self.image_size)
+            else:
+                output_expected_shape = None
 
             # Add bottom-up deterministic layer at level i.
             # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them.
@@ -308,14 +367,14 @@ class LadderVAE(nn.Module):
                    n_filters=self.encoder_n_filters,
                    downsampling_steps=self.downsample[i],
                    nonlin=nonlin,
+                    conv_strides=self.encoder_conv_strides,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.encoder_res_block_kernel,
-                    res_block_skip_padding=self.encoder_res_block_skip_padding,
                    gated=self.gated,
                    lowres_separate_branch=lowres_separate_branch,
-                    enable_multiscale=self.enable_multiscale,  # shouldn't the arg be `layer_enable_multiscale` here?
+                    enable_multiscale=self.enable_multiscale,  # TODO: shouldn't the arg be `layer_enable_multiscale` here?
                    multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims,
                    multiscale_lowres_size_factor=multiscale_lowres_size_factor,
                    decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
@@ -327,7 +386,8 @@ class LadderVAE(nn.Module):
 
     def create_top_down_layers(self) -> nn.ModuleList:
         """
-
+        Method creates the stack of top-down layers of the Decoder.
+
         In these layer the `bu`_values` from the Encoder are merged with the `p_params` from the previous layer
         of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution
         with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to
@@ -346,8 +406,6 @@ class LadderVAE(nn.Module):
         When doing unconditional generation, bu_value is not available. Hence the
         merge layer is not used, and z is sampled directly from p_params.
 
-        Parameters
-        ----------
         """
         top_down_layers = nn.ModuleList([])
         nonlin = get_activation(self.nonlin)
@@ -356,7 +414,7 @@ class LadderVAE(nn.Module):
            # Check if this is the top layer
            is_top = i == self.n_layers - 1
 
-            if self._enable_topdown_normalize_factor:
+            if self._enable_topdown_normalize_factor:  # TODO: What is this?
                normalize_latent_factor = (
                    1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0
                )
@@ -369,7 +427,8 @@ class LadderVAE(nn.Module):
                    n_res_blocks=self.decoder_blocks_per_layer,
                    n_filters=self.decoder_n_filters,
                    is_top_layer=is_top,
-
+                    conv_strides=self.decoder_conv_strides,
+                    upsampling_steps=self.downsample[i],
                    nonlin=nonlin,
                    merge_type=self.merge_type,
                    batchnorm=self.topdown_batchnorm,
@@ -379,17 +438,11 @@ class LadderVAE(nn.Module):
                    top_prior_param_shape=self.get_top_prior_param_shape(),
                    res_block_type=self.res_block_type,
                    res_block_kernel=self.decoder_res_block_kernel,
-                    res_block_skip_padding=self.decoder_res_block_skip_padding,
                    gated=self.gated,
                    analytical_kl=self.analytical_kl,
-                    restricted_kl=self._restricted_kl,
                    vanilla_latent_hw=self.get_latent_spatial_size(i),
-                    # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so merging operation does not work natively.
-                    bottomup_no_padding_mode=self.encoder_no_padding_mode,
-                    topdown_no_padding_mode=self.decoder_no_padding_mode,
                    retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims,
-
-                    input_image_shape=self.img_shape,
+                    input_image_shape=self.image_size,
                    normalize_latent_factor=normalize_latent_factor,
                    conv2d_bias=self.topdown_conv2d_bias,
                    stochastic_use_naive_exponential=self._stochastic_use_naive_exponential,
@@ -398,8 +451,10 @@ class LadderVAE(nn.Module):
         return top_down_layers
 
     def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential:
-        """
-
+        """Create the final top-down layer of the Decoder.
+
+        NOTE: In this layer, (optional) upsampling is performed by bilinear interpolation
+        instead of transposed convolution (like in other TD layers).
 
         Parameters
         ----------
@@ -419,69 +474,76 @@ class LadderVAE(nn.Module):
                c_in=self.decoder_n_filters,
                c_out=self.decoder_n_filters,
                nonlin=get_activation(self.nonlin),
+                conv_strides=self.decoder_conv_strides,
                batchnorm=self.topdown_batchnorm,
                dropout=self.decoder_dropout,
                res_block_type=self.res_block_type,
                res_block_kernel=self.decoder_res_block_kernel,
-                skip_padding=self.decoder_res_block_skip_padding,
                gated=self.gated,
                conv2d_bias=self.topdown_conv2d_bias,
            )
        )
        return nn.Sequential(*modules)
 
-    def _init_multires(
-        self, config=None
-    ) -> nn.ModuleList:  # TODO config: ml_collections.ConfigDict refactor
+    def _init_multires(self, config=None) -> nn.ModuleList:
        """
-
-
-
-
-
-
-
-
-
-
-
+        Method defines the input block/branch to encode/compress low-res lateral inputs.
+
+        at different hierarchical levels
+        in the multiresolution approach (LC). The role of the input branches is similar
+        to the one of the first bottom-up layer in the primary flow of the Encoder,
+        namely to compress the lateral input image to a degree that is compatible with
+        the one of the primary flow.
+
+        NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity
+        + BottomUpDeterministicResBlock. It is meaningful to observe that the
+        `BottomUpDeterministicResBlock` shares the same model attributes with the blocks
+        in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc. etc.).
+        Moreover, it does not perform downsampling.
+
+        NOTE 2: `_multiscale_count` attribute defines the total number of inputs to the
+        bottom-up pass. In other terms if we have the input patch and n_LC additional
+        lateral inputs, we will have a total of (n_LC + 1) inputs.
        """
        stride = 1 if self.no_initial_downscaling else 2
        nonlin = get_activation(self.nonlin)
        if self._multiscale_count is None:
            self._multiscale_count = 1
 
-        msg =
-
+        msg = (
+            f"Multiscale count ({self._multiscale_count}) should not exceed the number"
+            f"of bottom up layers ({self.n_layers}) by more than 1.\n"
+        )
        assert (
            self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers
-        ), msg
+        ), msg  # TODO how ?
 
        msg = (
-            "
+            "Multiscale approach only supports monocrome images. "
+            f"Found instead color_ch={self.color_ch}."
        )
-        assert self._multiscale_count == 1 or self.color_ch == 1, msg
+        # assert self._multiscale_count == 1 or self.color_ch == 1, msg
 
        lowres_first_bottom_ups = []
        for _ in range(1, self._multiscale_count):
            first_bottom_up = nn.Sequential(
-
+                self.encoder_conv_op(
                    in_channels=self.color_ch,
                    out_channels=self.encoder_n_filters,
                    kernel_size=5,
-                    padding=
+                    padding="same",
                    stride=stride,
                ),
                nonlin,
                BottomUpDeterministicResBlock(
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
+                    conv_strides=self.encoder_conv_strides,
                    nonlin=nonlin,
                    downsample=False,
                    batchnorm=self.bottomup_batchnorm,
                    dropout=self.encoder_dropout,
                    res_block_type=self.res_block_type,
-                    skip_padding=self.encoder_res_block_skip_padding,
                ),
            )
            lowres_first_bottom_ups.append(first_bottom_up)
@@ -493,10 +555,9 @@ class LadderVAE(nn.Module):
        )
 
    ### SET OF FORWARD-LIKE METHODS
-    def bottomup_pass(self, inp: torch.Tensor) ->
-        """
-
-        """
+    def bottomup_pass(self, inp: torch.Tensor) -> list[torch.Tensor]:
+        """Wrapper of _bottomup_pass()."""
+        # TODO Remove wrapper
        return self._bottomup_pass(
            inp,
            self.first_bottom_up,
@@ -510,9 +571,10 @@ class LadderVAE(nn.Module):
        first_bottom_up: nn.Sequential,
        lowres_first_bottom_ups: nn.ModuleList,
        bottom_up_layers: nn.ModuleList,
-    ) ->
+    ) -> list[torch.Tensor]:
        """
-
+        Method defines the forward pass through the LVAE Encoder, the so-called.
+
        Bottom-Up pass.
 
        Parameters
@@ -541,7 +603,6 @@ class LadderVAE(nn.Module):
            lowres_x = None
            if self._multiscale_count > 1 and i + 1 < inp.shape[1]:
                lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2])
-
            x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x)
            bu_values.append(bu_value)
 
@@ -549,41 +610,40 @@ class LadderVAE(nn.Module):
 
    def topdown_pass(
        self,
-        bu_values: torch.Tensor = None,
-        n_img_prior: torch.Tensor = None,
-
-
-
-
-
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        bu_values: Union[torch.Tensor, None] = None,
+        n_img_prior: Union[torch.Tensor, None] = None,
+        constant_layers: Union[Iterable[int], None] = None,
+        forced_latent: Union[list[torch.Tensor], None] = None,
+        top_down_layers: Union[nn.ModuleList, None] = None,
+        final_top_down_layer: Union[nn.Sequential, None] = None,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
-
+        Method defines the forward pass through the LVAE Decoder, the so-called.
+
        Top-Down pass.
 
        Parameters
        ----------
        bu_values: torch.Tensor, optional
-            Output of the bottom-up pass. It will have values from multiple layers of
+            Output of the bottom-up pass. It will have values from multiple layers of
+            the ladder.
        n_img_prior: optional
-            When `bu_values` is `None`, `n_img_prior` indicates the number of images to
+            When `bu_values` is `None`, `n_img_prior` indicates the number of images to
+            generate
            from the prior (so bottom-up pass is not used at all here).
-        mode_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which sampling is disabled and
-            the mode (mean value) is used instead. Set to `None` to avoid this behaviour.
        constant_layers: Iterable[int], optional
-            A sequence of indexes associated to the layers in which a single instance's
-            copied over the entire batch (bottom-up path is not used, so only prior
-            Set to `None` to avoid this behaviour.
-        forced_latent:
-            A list of tensors that are used as fixed latent variables (hence, sampling
-            place in this case).
+            A sequence of indexes associated to the layers in which a single instance's
+            z is copied over the entire batch (bottom-up path is not used, so only prior
+            is used here). Set to `None` to avoid this behaviour.
+        forced_latent: list[torch.Tensor], optional
+            A list of tensors that are used as fixed latent variables (hence, sampling
+            doesn't take place in this case).
        top_down_layers: nn.ModuleList, optional
-            A list of top-down layers to use in the top-down pass. If `None`, the method
-            default layers defined in the constructor.
+            A list of top-down layers to use in the top-down pass. If `None`, the method
+            uses the default layers defined in the constructor.
        final_top_down_layer: nn.Sequential, optional
-            The last top-down layer of the top-down pass. If `None`, the method uses the
-            layers defined in the constructor.
+            The last top-down layer of the top-down pass. If `None`, the method uses the
+            default layers defined in the constructor.
        """
        if top_down_layers is None:
            top_down_layers = self.top_down_layers
@@ -591,11 +651,9 @@ class LadderVAE(nn.Module):
            final_top_down_layer = self.final_top_down
 
        # Default: no layer is sampled from the distribution's mode
-        if mode_layers is None:
-            mode_layers = []
        if constant_layers is None:
            constant_layers = []
-        prior_experiment = len(
+        prior_experiment = len(constant_layers) > 0
 
        # If the bottom-up inference values are not given, don't do
        # inference, sample from prior instead
@@ -608,11 +666,7 @@ class LadderVAE(nn.Module):
                "if and only if we're not doing inference"
            )
            raise RuntimeError(msg)
-        if
-            inference_mode
-            and prior_experiment
-            and (self.non_stochastic_version is False)
-        ):
+        if inference_mode and prior_experiment:
            msg = (
                "Prior experiments (e.g. sampling from mode) are not"
                " compatible with inference mode"
@@ -621,34 +675,24 @@ class LadderVAE(nn.Module):
 
        # Sampled latent variables at each layer
        z = [None] * self.n_layers
-
        # KL divergence of each layer
        kl = [None] * self.n_layers
        # Kl divergence restricted, only for the LC enabled setup denoiSplit.
        kl_restricted = [None] * self.n_layers
-
        # mean from which z is sampled.
        q_mu = [None] * self.n_layers
        # log(var) from which z is sampled.
        q_lv = [None] * self.n_layers
-
        # Spatial map of KL divergence for each layer
        kl_spatial = [None] * self.n_layers
-
        debug_qvar_max = [None] * self.n_layers
-
        kl_channelwise = [None] * self.n_layers
-
        if forced_latent is None:
            forced_latent = [None] * self.n_layers
 
-        # log p(z) where z is the sample in the topdown pass
-        # logprob_p = 0.
-
        # Top-down inference/generation loop
-        out =
+        out = None
        for i in reversed(range(self.n_layers)):
-
            # If available, get deterministic node from bottom-up inference
            try:
                bu_value = bu_values[i]
@@ -656,26 +700,23 @@ class LadderVAE(nn.Module):
                bu_value = None
 
            # Whether the current layer should be sampled from the mode
-            use_mode = i in mode_layers
            constant_out = i in constant_layers
 
            # Input for skip connection
-            skip_input = out
+            skip_input = out
 
            # Full top-down layer, including sampling and deterministic part
-            out,
+            out, aux = top_down_layers[i](
                input_=out,
                skip_connection_input=skip_input,
                inference_mode=inference_mode,
                bu_value=bu_value,
                n_img_prior=n_img_prior,
-                use_mode=use_mode,
                force_constant_output=constant_out,
                forced_latent=forced_latent[i],
                mode_pred=self.mode_pred,
                var_clip_max=self._var_clip_max,
            )
-
            # Save useful variables
            z[i] = aux["z"]  # sampled variable at this layer (batch, ch, h, w)
            kl[i] = aux["kl_samplewise"]  # (batch, )
@@ -708,8 +749,10 @@ class LadderVAE(nn.Module):
        }
        return out, data
 
-    def forward(self, x: torch.Tensor) ->
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
+        Forward pass through the LVAE model.
+
        Parameters
        ----------
        x: torch.Tensor
@@ -717,124 +760,40 @@ class LadderVAE(nn.Module):
        """
        img_size = x.size()[2:]
 
-        # Pad input to size equal to the closest power of 2
-        x_pad = self.pad_input(x)
-
        # Bottom-up inference: return list of length n_layers (bottom to top)
-        bu_values = self.bottomup_pass(
+        bu_values = self.bottomup_pass(x)
        for i in range(0, self.skip_bottomk_buvalues):
            bu_values[i] = None
 
-
+        if self._squish3d:
+            bu_values = [
+                torch.mean(self._3D_squisher[k](bu_value), dim=2)
+                for k, bu_value in enumerate(bu_values)
+            ]
 
        # Top-down inference/generation
-        out, td_data = self.topdown_pass(bu_values
+        out, td_data = self.topdown_pass(bu_values)
 
        if out.shape[-1] > img_size[-1]:
            # Restore original image size
            out = crop_img_tensor(out, img_size)
 
        out = self.output_layer(out)
-        if self._tethered_to_input:
-            assert out.shape[1] == 1
-            ch2 = self.get_other_channel(out, x_pad)
-            out = torch.cat([out, ch2], dim=1)
 
        return out, td_data
 
-    ### SET OF UTILS METHODS
-    # def sample_prior(
-    #     self,
-    #     n_imgs,
-    #     mode_layers=None,
-    #     constant_layers=None
-    # ):
-
-    #     # Generate from prior
-    #     out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers)
-    #     out = crop_img_tensor(out, self.img_shape)
-
-    #     # Log likelihood and other info (per data point)
-    #     _, likelihood_data = self.likelihood(out, None)
-
-    #     return likelihood_data['sample']
-
-    # ### ???
-    # def sample_from_q(self, x, masks=None):
-    #     """
-    #     This method performs the bottomup_pass() and samples from the
-    #     obtained distribution.
-    #     """
-    #     img_size = x.size()[2:]
-
-    #     # Pad input to make everything easier with conv strides
-    #     x_pad = self.pad_input(x)
-
-    #     # Bottom-up inference: return list of length n_layers (bottom to top)
-    #     bu_values = self.bottomup_pass(x_pad)
-    #     return self._sample_from_q(bu_values, masks=masks)
-    # ### ???
-
-    # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None):
-    #     if top_down_layers is None:
-    #         top_down_layers = self.top_down_layers
-    #     if final_top_down_layer is None:
-    #         final_top_down_layer = self.final_top_down
-    #     if masks is None:
-    #         masks = [None] * len(bu_values)
-
-    #     msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this."
-    #     assert self.n_layers == 1, msg
-    #     samples = []
-    #     for i in reversed(range(self.n_layers)):
-    #         bu_value = bu_values[i]
-
-    #         # Note that the first argument can be set to None since we are just dealing with one level
-    #         sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i])
-    #         samples.append(sample)
-
-    #     return samples
-
-    def reset_for_different_output_size(self, output_size: int) -> None:
-        """Reset shape of output and latent tensors for different output size.
-
-        Used during evaluation to reset expected shapes of tensors when
-        input/output shape changes.
-        For instance, it is needed when the model was trained on, say, 64x64 sized
-        patches, but prediction is done on 128x128 patches.
-        """
-        for i in range(self.n_layers):
-            sz = output_size // 2 ** (1 + i)
-            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-            self.top_down_layers[i].latent_shape = (output_size, output_size)
-
-    def pad_input(self, x):
-        """
-        Pads input x so that its sizes are powers of 2
-        :param x:
-        :return: Padded tensor
-        """
-        size = self.get_padded_size(x.size())
-        x = pad_img_tensor(x, size)
-        return x
-
    ### SET OF GETTERS
    def get_padded_size(self, size):
        """
        Returns the smallest size (H, W) of the image with actual size given
        as input, such that H and W are powers of 2.
-        :param size: input size, tuple either (N, C, H,
+        :param size: input size, tuple either (N, C, H, W) or (H, W)
        :return: 2-tuple (H, W)
        """
        # Make size argument into (heigth, width)
-
-
-
-        msg = (
-            "input size must be either (N, C, H, W) or (H, W), but it "
-            f"has length {len(size)} (size={size})"
-        )
-        raise RuntimeError(msg)
+        # assert len(size) in [2, 4, 5] # TODO commented out cuz it's weird
+        # We're only interested in the Y,X dimensions
+        size = size[-2:]
 
        if self.multiscale_decoder_retain_spatial_dims is True:
            # In this case, we can go much more deeper and so this is not required
@@ -845,24 +804,21 @@ class LadderVAE(nn.Module):
        dwnsc = self.overall_downscale_factor
 
        # Output smallest powers of 2 that are larger than current sizes
-        padded_size =
-
+        padded_size = [((s - 1) // dwnsc + 1) * dwnsc for s in size]
+        # TODO Needed for pad/crop odd sizes. Move to dataset?
        return padded_size
 
    def get_latent_spatial_size(self, level_idx: int):
-        """
-        level_idx: 0 is the bottommost layer, the highest resolution one.
-        """
+        """Level_idx: 0 is the bottommost layer, the highest resolution one."""
        actual_downsampling = level_idx + 1
        dwnsc = 2**actual_downsampling
-        sz = self.get_padded_size(self.
+        sz = self.get_padded_size(self.image_size)
        h = sz[0] // dwnsc
        w = sz[1] // dwnsc
        assert h == w
        return h
 
    def get_top_prior_param_shape(self, n_imgs: int = 1):
-        # TODO num channels depends on random variable we're using
 
        # Compute the total downscaling performed in the Encoder
        if self.multiscale_decoder_retain_spatial_dims is False:
@@ -872,26 +828,12 @@ class LadderVAE(nn.Module):
            actual_downsampling = self.n_layers + 1 - self._multiscale_count
        dwnsc = 2**actual_downsampling
 
-
-
-
-
-
+        h = self.image_size[-2] // dwnsc
+        w = self.image_size[-1] // dwnsc
+        mu_logvar = self.z_dims[-1] * 2  # mu and logvar
+        top_layer_shape = (n_imgs, mu_logvar, h, w)
+        # TODO refactor!
+        if self._model_3D_depth > 1 and self._decoder_mode_3D is True:
+            # TODO check if model_3D_depth is needed ?
+            top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
        return top_layer_shape
-
-    def get_other_channel(self, ch1, input):
-        assert self.data_std["target"].squeeze().shape == (2,)
-        assert self.data_mean["target"].squeeze().shape == (2,)
-        assert self.target_ch == 2
-        ch1_un = (
-            ch1[:, :1] * self.data_std["target"][:, :1]
-            + self.data_mean["target"][:, :1]
-        )
-        input_un = input * self.data_std["input"] + self.data_mean["input"]
-        ch2_un = self._tethered_ch2_scalar * (
-            input_un - ch1_un * self._tethered_ch1_scalar
-        )
-        ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][
-            :, -1:
-        ]
-        return ch2
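For readers following the new 3D handling: when the encoder is 3D but the decoder is 2D (self._squish3d), forward() now passes each bottom-up value through a per-level GateLayer and averages it over the depth axis before the top-down pass. GateLayer lives in careamics/models/lvae/layers.py and its internals are not part of this diff, so the shape-only sketch below uses a stand-in identity gate; only the reduction over the Z dimension is the point.

    # Shape-only sketch of the 3D -> 2D "squish" applied to bu_values in forward().
    # nn.Identity stands in for careamics' GateLayer, whose definition is not in this diff.
    import torch
    import torch.nn as nn

    n_layers, n_filters, depth = 4, 64, 5
    gates = nn.ModuleList([nn.Identity() for _ in range(n_layers)])

    # One bottom-up value per ladder level: (batch, channels, Z, Y, X)
    bu_values = [torch.randn(1, n_filters, depth, 32, 32) for _ in range(n_layers)]

    # Mirrors: torch.mean(self._3D_squisher[k](bu_value), dim=2)
    squished = [torch.mean(gates[k](v), dim=2) for k, v in enumerate(bu_values)]

    print(squished[0].shape)  # torch.Size([1, 64, 32, 32]) -> 2D feature maps for the decoder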