careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/lvae.py
CHANGED
|
@@ -4,74 +4,69 @@ Ladder VAE (LVAE) Model
|
|
|
4
4
|
The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Iterable
|
|
8
|
+
from typing import Dict, List, Tuple
|
|
8
9
|
|
|
9
|
-
import ml_collections
|
|
10
10
|
import numpy as np
|
|
11
11
|
import torch
|
|
12
12
|
import torch.nn as nn
|
|
13
13
|
|
|
14
|
+
from careamics.config.architectures import register_model
|
|
15
|
+
|
|
16
|
+
from ..activation import get_activation
|
|
14
17
|
from .layers import (
|
|
15
18
|
BottomUpDeterministicResBlock,
|
|
16
19
|
BottomUpLayer,
|
|
17
20
|
TopDownDeterministicResBlock,
|
|
18
21
|
TopDownLayer,
|
|
19
22
|
)
|
|
20
|
-
from .
|
|
21
|
-
from .noise_models import get_noise_model
|
|
22
|
-
from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor
|
|
23
|
+
from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
@register_model("LVAE")
|
|
25
27
|
class LadderVAE(nn.Module):
|
|
26
28
|
|
|
27
29
|
def __init__(
|
|
28
30
|
self,
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
input_shape: int,
|
|
32
|
+
output_channels: int,
|
|
33
|
+
multiscale_count: int,
|
|
34
|
+
z_dims: List[int],
|
|
35
|
+
encoder_n_filters: int,
|
|
36
|
+
decoder_n_filters: int,
|
|
37
|
+
encoder_dropout: float,
|
|
38
|
+
decoder_dropout: float,
|
|
39
|
+
nonlinearity: str,
|
|
40
|
+
predict_logvar: bool,
|
|
41
|
+
analytical_kl: bool,
|
|
34
42
|
):
|
|
35
43
|
"""
|
|
36
44
|
Constructor.
|
|
37
45
|
|
|
38
46
|
Parameters
|
|
39
47
|
----------
|
|
40
|
-
|
|
41
|
-
The mean of the data used for normalization.
|
|
42
|
-
data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
|
|
43
|
-
The standard deviation of the data used for normalization.
|
|
44
|
-
config: ml_collections.ConfigDict
|
|
45
|
-
The configuration object of the model.
|
|
46
|
-
use_uncond_mode_at: Iterable[int], optional
|
|
47
|
-
A sequence of indexes associated to the layers in which sampling is disabled
|
|
48
|
-
and the mode (mean value) is used instead. Default is `[]`.
|
|
49
|
-
target_ch: int, optional
|
|
50
|
-
The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
|
|
51
|
-
Default is `2`.
|
|
48
|
+
|
|
52
49
|
"""
|
|
53
50
|
super().__init__()
|
|
54
51
|
|
|
55
52
|
# -------------------------------------------------------
|
|
56
53
|
# Customizable attributes
|
|
57
|
-
self.image_size =
|
|
58
|
-
self.
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.
|
|
62
|
-
self.
|
|
63
|
-
self.
|
|
64
|
-
self.
|
|
65
|
-
self.
|
|
66
|
-
self.
|
|
67
|
-
self.
|
|
68
|
-
self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath
|
|
69
|
-
self.analytical_kl = config.model.analytical_kl
|
|
54
|
+
self.image_size = input_shape
|
|
55
|
+
self.target_ch = output_channels
|
|
56
|
+
self._multiscale_count = multiscale_count
|
|
57
|
+
self.z_dims = z_dims
|
|
58
|
+
self.encoder_n_filters = encoder_n_filters
|
|
59
|
+
self.decoder_n_filters = decoder_n_filters
|
|
60
|
+
self.encoder_dropout = encoder_dropout
|
|
61
|
+
self.decoder_dropout = decoder_dropout
|
|
62
|
+
self.nonlin = nonlinearity
|
|
63
|
+
self.predict_logvar = predict_logvar
|
|
64
|
+
self.analytical_kl = analytical_kl
|
|
70
65
|
# -------------------------------------------------------
|
|
71
66
|
|
|
72
67
|
# -------------------------------------------------------
|
|
73
68
|
# Model attributes -> Hardcoded
|
|
74
|
-
self.model_type = ModelType.LadderVae
|
|
69
|
+
self.model_type = ModelType.LadderVae # TODO remove !
|
|
75
70
|
self.encoder_blocks_per_layer = 1
|
|
76
71
|
self.decoder_blocks_per_layer = 1
|
|
77
72
|
self.bottomup_batchnorm = True
|
|
@@ -88,20 +83,13 @@ class LadderVAE(nn.Module):
|
|
|
88
83
|
self.non_stochastic_version = False
|
|
89
84
|
self.stochastic_skip = True
|
|
90
85
|
self.learn_top_prior = True
|
|
91
|
-
self.res_block_type = "bacdbacd"
|
|
86
|
+
self.res_block_type = "bacdbacd" # TODO remove !
|
|
92
87
|
self.mode_pred = False
|
|
93
88
|
self.logvar_lowerbound = -5
|
|
94
89
|
self._var_clip_max = 20
|
|
95
90
|
self._stochastic_use_naive_exponential = False
|
|
96
91
|
self._enable_topdown_normalize_factor = True
|
|
97
92
|
|
|
98
|
-
# Noise model attributes -> Hardcoded
|
|
99
|
-
self.noise_model_type = "gmm"
|
|
100
|
-
self.denoise_channel = (
|
|
101
|
-
"input" # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'}
|
|
102
|
-
)
|
|
103
|
-
self.noise_model_learnable = False
|
|
104
|
-
|
|
105
93
|
# Attributes that handle LC -> Hardcoded
|
|
106
94
|
self.enable_multiscale = (
|
|
107
95
|
self._multiscale_count is not None and self._multiscale_count > 1
|
|
@@ -151,8 +139,6 @@ class LadderVAE(nn.Module):
|
|
|
151
139
|
self.mixed_rec_w = 0
|
|
152
140
|
self.nbr_consistency_w = 0
|
|
153
141
|
|
|
154
|
-
# Setting the loss_type
|
|
155
|
-
self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
|
|
156
142
|
# -------------------------------------------------------
|
|
157
143
|
|
|
158
144
|
# -------------------------------------------------------
|
|
@@ -167,51 +153,6 @@ class LadderVAE(nn.Module):
|
|
|
167
153
|
# -------------------------------------------------------
|
|
168
154
|
|
|
169
155
|
# -------------------------------------------------------
|
|
170
|
-
# Attributes from constructor arguments
|
|
171
|
-
self.target_ch = target_ch
|
|
172
|
-
self.use_uncond_mode_at = use_uncond_mode_at
|
|
173
|
-
|
|
174
|
-
# Data mean and std used for normalization
|
|
175
|
-
if isinstance(data_mean, np.ndarray):
|
|
176
|
-
self.data_mean = torch.Tensor(data_mean)
|
|
177
|
-
self.data_std = torch.Tensor(data_std)
|
|
178
|
-
elif isinstance(data_mean, dict):
|
|
179
|
-
for k in data_mean.keys():
|
|
180
|
-
data_mean[k] = (
|
|
181
|
-
torch.Tensor(data_mean[k])
|
|
182
|
-
if not isinstance(data_mean[k], dict)
|
|
183
|
-
else data_mean[k]
|
|
184
|
-
)
|
|
185
|
-
data_std[k] = (
|
|
186
|
-
torch.Tensor(data_std[k])
|
|
187
|
-
if not isinstance(data_std[k], dict)
|
|
188
|
-
else data_std[k]
|
|
189
|
-
)
|
|
190
|
-
self.data_mean = data_mean
|
|
191
|
-
self.data_std = data_std
|
|
192
|
-
else:
|
|
193
|
-
raise NotImplementedError(
|
|
194
|
-
"data_mean and data_std must be either a numpy array or a dictionary"
|
|
195
|
-
)
|
|
196
|
-
|
|
197
|
-
assert self.data_std is not None
|
|
198
|
-
assert self.data_mean is not None
|
|
199
|
-
|
|
200
|
-
# Initialize the Noise Model
|
|
201
|
-
self.likelihood_gm = self.likelihood_NM = None
|
|
202
|
-
self.noiseModel = get_noise_model(
|
|
203
|
-
enable_noise_model=self.enable_noise_model,
|
|
204
|
-
model_type=self.model_type,
|
|
205
|
-
noise_model_type=self.noise_model_type,
|
|
206
|
-
noise_model_ch1_fpath=self.noise_model_ch1_fpath,
|
|
207
|
-
noise_model_ch2_fpath=self.noise_model_ch2_fpath,
|
|
208
|
-
noise_model_learnable=self.noise_model_learnable,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
if self.noiseModel is None:
|
|
212
|
-
self.likelihood_form = "gaussian"
|
|
213
|
-
else:
|
|
214
|
-
self.likelihood_form = "noise_model"
|
|
215
156
|
|
|
216
157
|
# Calculate the downsampling happening in the network
|
|
217
158
|
self.downsample = [1] * self.n_layers
|
|
@@ -246,7 +187,7 @@ class LadderVAE(nn.Module):
|
|
|
246
187
|
)
|
|
247
188
|
|
|
248
189
|
# Likelihood module
|
|
249
|
-
self.likelihood = self.create_likelihood_module()
|
|
190
|
+
# self.likelihood = self.create_likelihood_module()
|
|
250
191
|
|
|
251
192
|
# Output layer --> Project to target_ch many channels
|
|
252
193
|
logvar_ch_needed = self.predict_logvar is not None
|
|
@@ -284,11 +225,11 @@ class LadderVAE(nn.Module):
|
|
|
284
225
|
Parameters
|
|
285
226
|
----------
|
|
286
227
|
init_stride: int
|
|
287
|
-
The stride used by the
|
|
228
|
+
The stride used by the initial Conv2d block.
|
|
288
229
|
num_res_blocks: int, optional
|
|
289
230
|
The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
|
|
290
231
|
"""
|
|
291
|
-
nonlin = self.
|
|
232
|
+
nonlin = get_activation(self.nonlin)
|
|
292
233
|
modules = [
|
|
293
234
|
nn.Conv2d(
|
|
294
235
|
in_channels=self.color_ch,
|
|
@@ -301,7 +242,7 @@ class LadderVAE(nn.Module):
|
|
|
301
242
|
),
|
|
302
243
|
stride=init_stride,
|
|
303
244
|
),
|
|
304
|
-
nonlin
|
|
245
|
+
nonlin,
|
|
305
246
|
]
|
|
306
247
|
|
|
307
248
|
for _ in range(num_res_blocks):
|
|
@@ -337,7 +278,7 @@ class LadderVAE(nn.Module):
|
|
|
337
278
|
not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
|
|
338
279
|
"""
|
|
339
280
|
multiscale_lowres_size_factor = 1
|
|
340
|
-
nonlin = self.
|
|
281
|
+
nonlin = get_activation(self.nonlin)
|
|
341
282
|
|
|
342
283
|
bottom_up_layers = nn.ModuleList([])
|
|
343
284
|
for i in range(self.n_layers):
|
|
@@ -409,7 +350,7 @@ class LadderVAE(nn.Module):
|
|
|
409
350
|
----------
|
|
410
351
|
"""
|
|
411
352
|
top_down_layers = nn.ModuleList([])
|
|
412
|
-
nonlin = self.
|
|
353
|
+
nonlin = get_activation(self.nonlin)
|
|
413
354
|
# NOTE: top-down layers are created starting from the bottom-most
|
|
414
355
|
for i in range(self.n_layers):
|
|
415
356
|
# Check if this is the top layer
|
|
@@ -477,7 +418,7 @@ class LadderVAE(nn.Module):
|
|
|
477
418
|
TopDownDeterministicResBlock(
|
|
478
419
|
c_in=self.decoder_n_filters,
|
|
479
420
|
c_out=self.decoder_n_filters,
|
|
480
|
-
nonlin=self.
|
|
421
|
+
nonlin=get_activation(self.nonlin),
|
|
481
422
|
batchnorm=self.topdown_batchnorm,
|
|
482
423
|
dropout=self.decoder_dropout,
|
|
483
424
|
res_block_type=self.res_block_type,
|
|
@@ -489,34 +430,9 @@ class LadderVAE(nn.Module):
|
|
|
489
430
|
)
|
|
490
431
|
return nn.Sequential(*modules)
|
|
491
432
|
|
|
492
|
-
def
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
|
|
496
|
-
"""
|
|
497
|
-
self.likelihood_gm = GaussianLikelihood(
|
|
498
|
-
self.decoder_n_filters,
|
|
499
|
-
self.target_ch,
|
|
500
|
-
predict_logvar=self.predict_logvar,
|
|
501
|
-
logvar_lowerbound=self.logvar_lowerbound,
|
|
502
|
-
conv2d_bias=self.topdown_conv2d_bias,
|
|
503
|
-
)
|
|
504
|
-
|
|
505
|
-
self.likelihood_NM = None
|
|
506
|
-
if self.enable_noise_model:
|
|
507
|
-
self.likelihood_NM = NoiseModelLikelihood(
|
|
508
|
-
self.decoder_n_filters,
|
|
509
|
-
self.target_ch,
|
|
510
|
-
self.data_mean,
|
|
511
|
-
self.data_std,
|
|
512
|
-
self.noiseModel,
|
|
513
|
-
)
|
|
514
|
-
if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
|
|
515
|
-
return self.likelihood_gm
|
|
516
|
-
|
|
517
|
-
return self.likelihood_NM
|
|
518
|
-
|
|
519
|
-
def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList:
|
|
433
|
+
def _init_multires(
|
|
434
|
+
self, config=None
|
|
435
|
+
) -> nn.ModuleList: # TODO config: ml_collections.ConfigDict refactor
|
|
520
436
|
"""
|
|
521
437
|
This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
|
|
522
438
|
in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
|
|
@@ -531,7 +447,7 @@ class LadderVAE(nn.Module):
|
|
|
531
447
|
In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
|
|
532
448
|
"""
|
|
533
449
|
stride = 1 if self.no_initial_downscaling else 2
|
|
534
|
-
nonlin = self.
|
|
450
|
+
nonlin = get_activation(self.nonlin)
|
|
535
451
|
if self._multiscale_count is None:
|
|
536
452
|
self._multiscale_count = 1
|
|
537
453
|
|
|
@@ -556,7 +472,7 @@ class LadderVAE(nn.Module):
|
|
|
556
472
|
padding=2,
|
|
557
473
|
stride=stride,
|
|
558
474
|
),
|
|
559
|
-
nonlin
|
|
475
|
+
nonlin,
|
|
560
476
|
BottomUpDeterministicResBlock(
|
|
561
477
|
c_in=self.encoder_n_filters,
|
|
562
478
|
c_out=self.encoder_n_filters,
|
|
@@ -596,7 +512,7 @@ class LadderVAE(nn.Module):
|
|
|
596
512
|
bottom_up_layers: nn.ModuleList,
|
|
597
513
|
) -> List[torch.Tensor]:
|
|
598
514
|
"""
|
|
599
|
-
This method defines the forward pass
|
|
515
|
+
This method defines the forward pass through the LVAE Encoder, the so-called
|
|
600
516
|
Bottom-Up pass.
|
|
601
517
|
|
|
602
518
|
Parameters
|
|
@@ -642,7 +558,7 @@ class LadderVAE(nn.Module):
|
|
|
642
558
|
final_top_down_layer: nn.Sequential = None,
|
|
643
559
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
|
644
560
|
"""
|
|
645
|
-
This method defines the forward pass
|
|
561
|
+
This method defines the forward pass through the LVAE Decoder, the so-called
|
|
646
562
|
Top-Down pass.
|
|
647
563
|
|
|
648
564
|
Parameters
|
|
@@ -664,10 +580,10 @@ class LadderVAE(nn.Module):
|
|
|
664
580
|
place in this case).
|
|
665
581
|
top_down_layers: nn.ModuleList, optional
|
|
666
582
|
A list of top-down layers to use in the top-down pass. If `None`, the method uses the
|
|
667
|
-
default layers defined in the
|
|
583
|
+
default layers defined in the constructor.
|
|
668
584
|
final_top_down_layer: nn.Sequential, optional
|
|
669
585
|
The last top-down layer of the top-down pass. If `None`, the method uses the default
|
|
670
|
-
layers defined in the
|
|
586
|
+
layers defined in the constructor.
|
|
671
587
|
"""
|
|
672
588
|
if top_down_layers is None:
|
|
673
589
|
top_down_layers = self.top_down_layers
|
|
@@ -742,7 +658,6 @@ class LadderVAE(nn.Module):
|
|
|
742
658
|
# Whether the current layer should be sampled from the mode
|
|
743
659
|
use_mode = i in mode_layers
|
|
744
660
|
constant_out = i in constant_layers
|
|
745
|
-
use_uncond_mode = i in self.use_uncond_mode_at
|
|
746
661
|
|
|
747
662
|
# Input for skip connection
|
|
748
663
|
skip_input = out # TODO or n? or both?
|
|
@@ -758,7 +673,6 @@ class LadderVAE(nn.Module):
|
|
|
758
673
|
force_constant_output=constant_out,
|
|
759
674
|
forced_latent=forced_latent[i],
|
|
760
675
|
mode_pred=self.mode_pred,
|
|
761
|
-
use_uncond_mode=use_uncond_mode,
|
|
762
676
|
var_clip_max=self._var_clip_max,
|
|
763
677
|
)
|
|
764
678
|
|
|
@@ -881,11 +795,18 @@ class LadderVAE(nn.Module):
|
|
|
881
795
|
|
|
882
796
|
# return samples
|
|
883
797
|
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
798
|
+
def reset_for_different_output_size(self, output_size: int) -> None:
|
|
799
|
+
"""Reset shape of output and latent tensors for different output size.
|
|
800
|
+
|
|
801
|
+
Used during evaluation to reset expected shapes of tensors when
|
|
802
|
+
input/output shape changes.
|
|
803
|
+
For instance, it is needed when the model was trained on, say, 64x64 sized
|
|
804
|
+
patches, but prediction is done on 128x128 patches.
|
|
805
|
+
"""
|
|
806
|
+
for i in range(self.n_layers):
|
|
807
|
+
sz = output_size // 2 ** (1 + i)
|
|
808
|
+
self.bottom_up_layers[i].output_expected_shape = (sz, sz)
|
|
809
|
+
self.top_down_layers[i].latent_shape = (output_size, output_size)
|
|
889
810
|
|
|
890
811
|
def pad_input(self, x):
|
|
891
812
|
"""
|
|
@@ -898,15 +819,6 @@ class LadderVAE(nn.Module):
|
|
|
898
819
|
return x
|
|
899
820
|
|
|
900
821
|
### SET OF GETTERS
|
|
901
|
-
def get_nonlin(self):
|
|
902
|
-
nonlin = {
|
|
903
|
-
"relu": nn.ReLU,
|
|
904
|
-
"leakyrelu": nn.LeakyReLU,
|
|
905
|
-
"elu": nn.ELU,
|
|
906
|
-
"selu": nn.SELU,
|
|
907
|
-
}
|
|
908
|
-
return nonlin[self.nonlin]
|
|
909
|
-
|
|
910
822
|
def get_padded_size(self, size):
|
|
911
823
|
"""
|
|
912
824
|
Returns the smallest size (H, W) of the image with actual size given
|