careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/lvae.py
CHANGED
|
@@ -4,74 +4,73 @@ Ladder VAE (LVAE) Model
|
|
|
4
4
|
The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Iterable
|
|
8
|
+
from typing import Dict, List, Tuple
|
|
8
9
|
|
|
9
|
-
import ml_collections
|
|
10
10
|
import numpy as np
|
|
11
11
|
import torch
|
|
12
12
|
import torch.nn as nn
|
|
13
13
|
|
|
14
|
+
from careamics.config.architectures import register_model
|
|
15
|
+
|
|
16
|
+
from ..activation import get_activation
|
|
14
17
|
from .layers import (
|
|
15
18
|
BottomUpDeterministicResBlock,
|
|
16
19
|
BottomUpLayer,
|
|
17
20
|
TopDownDeterministicResBlock,
|
|
18
21
|
TopDownLayer,
|
|
19
22
|
)
|
|
20
|
-
from .
|
|
21
|
-
from .noise_models import get_noise_model
|
|
22
|
-
from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor
|
|
23
|
+
from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
@register_model("LVAE")
|
|
25
27
|
class LadderVAE(nn.Module):
|
|
26
28
|
|
|
27
29
|
def __init__(
|
|
28
30
|
self,
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
input_shape: int,
|
|
32
|
+
output_channels: int,
|
|
33
|
+
multiscale_count: int,
|
|
34
|
+
z_dims: List[int],
|
|
35
|
+
encoder_n_filters: int,
|
|
36
|
+
decoder_n_filters: int,
|
|
37
|
+
encoder_dropout: float,
|
|
38
|
+
decoder_dropout: float,
|
|
39
|
+
nonlinearity: str,
|
|
40
|
+
predict_logvar: bool,
|
|
41
|
+
enable_noise_model: bool,
|
|
42
|
+
analytical_kl: bool,
|
|
34
43
|
):
|
|
35
44
|
"""
|
|
36
45
|
Constructor.
|
|
37
46
|
|
|
38
47
|
Parameters
|
|
39
48
|
----------
|
|
40
|
-
|
|
41
|
-
The mean of the data used for normalization.
|
|
42
|
-
data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
|
|
43
|
-
The standard deviation of the data used for normalization.
|
|
44
|
-
config: ml_collections.ConfigDict
|
|
45
|
-
The configuration object of the model.
|
|
46
|
-
use_uncond_mode_at: Iterable[int], optional
|
|
47
|
-
A sequence of indexes associated to the layers in which sampling is disabled
|
|
48
|
-
and the mode (mean value) is used instead. Default is `[]`.
|
|
49
|
-
target_ch: int, optional
|
|
50
|
-
The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
|
|
51
|
-
Default is `2`.
|
|
49
|
+
|
|
52
50
|
"""
|
|
53
51
|
super().__init__()
|
|
54
52
|
|
|
55
53
|
# -------------------------------------------------------
|
|
56
54
|
# Customizable attributes
|
|
57
|
-
self.image_size =
|
|
58
|
-
self.
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.
|
|
62
|
-
self.
|
|
63
|
-
self.
|
|
64
|
-
self.
|
|
65
|
-
self.
|
|
66
|
-
self.
|
|
67
|
-
self.
|
|
68
|
-
|
|
69
|
-
self.analytical_kl =
|
|
55
|
+
self.image_size = input_shape
|
|
56
|
+
self.target_ch = output_channels
|
|
57
|
+
self._multiscale_count = multiscale_count
|
|
58
|
+
self.z_dims = z_dims
|
|
59
|
+
self.encoder_n_filters = encoder_n_filters
|
|
60
|
+
self.decoder_n_filters = decoder_n_filters
|
|
61
|
+
self.encoder_dropout = encoder_dropout
|
|
62
|
+
self.decoder_dropout = decoder_dropout
|
|
63
|
+
self.nonlin = nonlinearity
|
|
64
|
+
self.predict_logvar = predict_logvar
|
|
65
|
+
self.enable_noise_model = enable_noise_model
|
|
66
|
+
|
|
67
|
+
self.analytical_kl = analytical_kl
|
|
70
68
|
# -------------------------------------------------------
|
|
71
69
|
|
|
72
70
|
# -------------------------------------------------------
|
|
73
71
|
# Model attributes -> Hardcoded
|
|
74
|
-
self.model_type = ModelType.LadderVae
|
|
72
|
+
self.model_type = ModelType.LadderVae # TODO remove !
|
|
73
|
+
self.model_type = ModelType.LadderVae # TODO remove !
|
|
75
74
|
self.encoder_blocks_per_layer = 1
|
|
76
75
|
self.decoder_blocks_per_layer = 1
|
|
77
76
|
self.bottomup_batchnorm = True
|
|
@@ -88,7 +87,7 @@ class LadderVAE(nn.Module):
|
|
|
88
87
|
self.non_stochastic_version = False
|
|
89
88
|
self.stochastic_skip = True
|
|
90
89
|
self.learn_top_prior = True
|
|
91
|
-
self.res_block_type = "bacdbacd"
|
|
90
|
+
self.res_block_type = "bacdbacd" # TODO remove !
|
|
92
91
|
self.mode_pred = False
|
|
93
92
|
self.logvar_lowerbound = -5
|
|
94
93
|
self._var_clip_max = 20
|
|
@@ -151,8 +150,6 @@ class LadderVAE(nn.Module):
|
|
|
151
150
|
self.mixed_rec_w = 0
|
|
152
151
|
self.nbr_consistency_w = 0
|
|
153
152
|
|
|
154
|
-
# Setting the loss_type
|
|
155
|
-
self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
|
|
156
153
|
# -------------------------------------------------------
|
|
157
154
|
|
|
158
155
|
# -------------------------------------------------------
|
|
@@ -167,51 +164,6 @@ class LadderVAE(nn.Module):
|
|
|
167
164
|
# -------------------------------------------------------
|
|
168
165
|
|
|
169
166
|
# -------------------------------------------------------
|
|
170
|
-
# Attributes from constructor arguments
|
|
171
|
-
self.target_ch = target_ch
|
|
172
|
-
self.use_uncond_mode_at = use_uncond_mode_at
|
|
173
|
-
|
|
174
|
-
# Data mean and std used for normalization
|
|
175
|
-
if isinstance(data_mean, np.ndarray):
|
|
176
|
-
self.data_mean = torch.Tensor(data_mean)
|
|
177
|
-
self.data_std = torch.Tensor(data_std)
|
|
178
|
-
elif isinstance(data_mean, dict):
|
|
179
|
-
for k in data_mean.keys():
|
|
180
|
-
data_mean[k] = (
|
|
181
|
-
torch.Tensor(data_mean[k])
|
|
182
|
-
if not isinstance(data_mean[k], dict)
|
|
183
|
-
else data_mean[k]
|
|
184
|
-
)
|
|
185
|
-
data_std[k] = (
|
|
186
|
-
torch.Tensor(data_std[k])
|
|
187
|
-
if not isinstance(data_std[k], dict)
|
|
188
|
-
else data_std[k]
|
|
189
|
-
)
|
|
190
|
-
self.data_mean = data_mean
|
|
191
|
-
self.data_std = data_std
|
|
192
|
-
else:
|
|
193
|
-
raise NotImplementedError(
|
|
194
|
-
"data_mean and data_std must be either a numpy array or a dictionary"
|
|
195
|
-
)
|
|
196
|
-
|
|
197
|
-
assert self.data_std is not None
|
|
198
|
-
assert self.data_mean is not None
|
|
199
|
-
|
|
200
|
-
# Initialize the Noise Model
|
|
201
|
-
self.likelihood_gm = self.likelihood_NM = None
|
|
202
|
-
self.noiseModel = get_noise_model(
|
|
203
|
-
enable_noise_model=self.enable_noise_model,
|
|
204
|
-
model_type=self.model_type,
|
|
205
|
-
noise_model_type=self.noise_model_type,
|
|
206
|
-
noise_model_ch1_fpath=self.noise_model_ch1_fpath,
|
|
207
|
-
noise_model_ch2_fpath=self.noise_model_ch2_fpath,
|
|
208
|
-
noise_model_learnable=self.noise_model_learnable,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
if self.noiseModel is None:
|
|
212
|
-
self.likelihood_form = "gaussian"
|
|
213
|
-
else:
|
|
214
|
-
self.likelihood_form = "noise_model"
|
|
215
167
|
|
|
216
168
|
# Calculate the downsampling happening in the network
|
|
217
169
|
self.downsample = [1] * self.n_layers
|
|
@@ -246,7 +198,7 @@ class LadderVAE(nn.Module):
|
|
|
246
198
|
)
|
|
247
199
|
|
|
248
200
|
# Likelihood module
|
|
249
|
-
self.likelihood = self.create_likelihood_module()
|
|
201
|
+
# self.likelihood = self.create_likelihood_module()
|
|
250
202
|
|
|
251
203
|
# Output layer --> Project to target_ch many channels
|
|
252
204
|
logvar_ch_needed = self.predict_logvar is not None
|
|
@@ -284,11 +236,11 @@ class LadderVAE(nn.Module):
|
|
|
284
236
|
Parameters
|
|
285
237
|
----------
|
|
286
238
|
init_stride: int
|
|
287
|
-
The stride used by the
|
|
239
|
+
The stride used by the initial Conv2d block.
|
|
288
240
|
num_res_blocks: int, optional
|
|
289
241
|
The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
|
|
290
242
|
"""
|
|
291
|
-
nonlin = self.
|
|
243
|
+
nonlin = get_activation(self.nonlin)
|
|
292
244
|
modules = [
|
|
293
245
|
nn.Conv2d(
|
|
294
246
|
in_channels=self.color_ch,
|
|
@@ -301,7 +253,7 @@ class LadderVAE(nn.Module):
|
|
|
301
253
|
),
|
|
302
254
|
stride=init_stride,
|
|
303
255
|
),
|
|
304
|
-
nonlin
|
|
256
|
+
nonlin,
|
|
305
257
|
]
|
|
306
258
|
|
|
307
259
|
for _ in range(num_res_blocks):
|
|
@@ -337,7 +289,7 @@ class LadderVAE(nn.Module):
|
|
|
337
289
|
not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
|
|
338
290
|
"""
|
|
339
291
|
multiscale_lowres_size_factor = 1
|
|
340
|
-
nonlin = self.
|
|
292
|
+
nonlin = get_activation(self.nonlin)
|
|
341
293
|
|
|
342
294
|
bottom_up_layers = nn.ModuleList([])
|
|
343
295
|
for i in range(self.n_layers):
|
|
@@ -409,7 +361,7 @@ class LadderVAE(nn.Module):
|
|
|
409
361
|
----------
|
|
410
362
|
"""
|
|
411
363
|
top_down_layers = nn.ModuleList([])
|
|
412
|
-
nonlin = self.
|
|
364
|
+
nonlin = get_activation(self.nonlin)
|
|
413
365
|
# NOTE: top-down layers are created starting from the bottom-most
|
|
414
366
|
for i in range(self.n_layers):
|
|
415
367
|
# Check if this is the top layer
|
|
@@ -477,7 +429,7 @@ class LadderVAE(nn.Module):
|
|
|
477
429
|
TopDownDeterministicResBlock(
|
|
478
430
|
c_in=self.decoder_n_filters,
|
|
479
431
|
c_out=self.decoder_n_filters,
|
|
480
|
-
nonlin=self.
|
|
432
|
+
nonlin=get_activation(self.nonlin),
|
|
481
433
|
batchnorm=self.topdown_batchnorm,
|
|
482
434
|
dropout=self.decoder_dropout,
|
|
483
435
|
res_block_type=self.res_block_type,
|
|
@@ -489,34 +441,9 @@ class LadderVAE(nn.Module):
|
|
|
489
441
|
)
|
|
490
442
|
return nn.Sequential(*modules)
|
|
491
443
|
|
|
492
|
-
def
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
|
|
496
|
-
"""
|
|
497
|
-
self.likelihood_gm = GaussianLikelihood(
|
|
498
|
-
self.decoder_n_filters,
|
|
499
|
-
self.target_ch,
|
|
500
|
-
predict_logvar=self.predict_logvar,
|
|
501
|
-
logvar_lowerbound=self.logvar_lowerbound,
|
|
502
|
-
conv2d_bias=self.topdown_conv2d_bias,
|
|
503
|
-
)
|
|
504
|
-
|
|
505
|
-
self.likelihood_NM = None
|
|
506
|
-
if self.enable_noise_model:
|
|
507
|
-
self.likelihood_NM = NoiseModelLikelihood(
|
|
508
|
-
self.decoder_n_filters,
|
|
509
|
-
self.target_ch,
|
|
510
|
-
self.data_mean,
|
|
511
|
-
self.data_std,
|
|
512
|
-
self.noiseModel,
|
|
513
|
-
)
|
|
514
|
-
if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
|
|
515
|
-
return self.likelihood_gm
|
|
516
|
-
|
|
517
|
-
return self.likelihood_NM
|
|
518
|
-
|
|
519
|
-
def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList:
|
|
444
|
+
def _init_multires(
|
|
445
|
+
self, config=None
|
|
446
|
+
) -> nn.ModuleList: # TODO config: ml_collections.ConfigDict refactor
|
|
520
447
|
"""
|
|
521
448
|
This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
|
|
522
449
|
in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
|
|
@@ -531,7 +458,7 @@ class LadderVAE(nn.Module):
|
|
|
531
458
|
In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
|
|
532
459
|
"""
|
|
533
460
|
stride = 1 if self.no_initial_downscaling else 2
|
|
534
|
-
nonlin = self.
|
|
461
|
+
nonlin = get_activation(self.nonlin)
|
|
535
462
|
if self._multiscale_count is None:
|
|
536
463
|
self._multiscale_count = 1
|
|
537
464
|
|
|
@@ -556,7 +483,7 @@ class LadderVAE(nn.Module):
|
|
|
556
483
|
padding=2,
|
|
557
484
|
stride=stride,
|
|
558
485
|
),
|
|
559
|
-
nonlin
|
|
486
|
+
nonlin,
|
|
560
487
|
BottomUpDeterministicResBlock(
|
|
561
488
|
c_in=self.encoder_n_filters,
|
|
562
489
|
c_out=self.encoder_n_filters,
|
|
@@ -596,7 +523,7 @@ class LadderVAE(nn.Module):
|
|
|
596
523
|
bottom_up_layers: nn.ModuleList,
|
|
597
524
|
) -> List[torch.Tensor]:
|
|
598
525
|
"""
|
|
599
|
-
This method defines the forward pass
|
|
526
|
+
This method defines the forward pass through the LVAE Encoder, the so-called
|
|
600
527
|
Bottom-Up pass.
|
|
601
528
|
|
|
602
529
|
Parameters
|
|
@@ -642,7 +569,7 @@ class LadderVAE(nn.Module):
|
|
|
642
569
|
final_top_down_layer: nn.Sequential = None,
|
|
643
570
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
|
644
571
|
"""
|
|
645
|
-
This method defines the forward pass
|
|
572
|
+
This method defines the forward pass through the LVAE Decoder, the so-called
|
|
646
573
|
Top-Down pass.
|
|
647
574
|
|
|
648
575
|
Parameters
|
|
@@ -664,10 +591,10 @@ class LadderVAE(nn.Module):
|
|
|
664
591
|
place in this case).
|
|
665
592
|
top_down_layers: nn.ModuleList, optional
|
|
666
593
|
A list of top-down layers to use in the top-down pass. If `None`, the method uses the
|
|
667
|
-
default layers defined in the
|
|
594
|
+
default layers defined in the constructor.
|
|
668
595
|
final_top_down_layer: nn.Sequential, optional
|
|
669
596
|
The last top-down layer of the top-down pass. If `None`, the method uses the default
|
|
670
|
-
layers defined in the
|
|
597
|
+
layers defined in the constructor.
|
|
671
598
|
"""
|
|
672
599
|
if top_down_layers is None:
|
|
673
600
|
top_down_layers = self.top_down_layers
|
|
@@ -742,7 +669,6 @@ class LadderVAE(nn.Module):
|
|
|
742
669
|
# Whether the current layer should be sampled from the mode
|
|
743
670
|
use_mode = i in mode_layers
|
|
744
671
|
constant_out = i in constant_layers
|
|
745
|
-
use_uncond_mode = i in self.use_uncond_mode_at
|
|
746
672
|
|
|
747
673
|
# Input for skip connection
|
|
748
674
|
skip_input = out # TODO or n? or both?
|
|
@@ -758,7 +684,6 @@ class LadderVAE(nn.Module):
|
|
|
758
684
|
force_constant_output=constant_out,
|
|
759
685
|
forced_latent=forced_latent[i],
|
|
760
686
|
mode_pred=self.mode_pred,
|
|
761
|
-
use_uncond_mode=use_uncond_mode,
|
|
762
687
|
var_clip_max=self._var_clip_max,
|
|
763
688
|
)
|
|
764
689
|
|
|
@@ -898,15 +823,6 @@ class LadderVAE(nn.Module):
|
|
|
898
823
|
return x
|
|
899
824
|
|
|
900
825
|
### SET OF GETTERS
|
|
901
|
-
def get_nonlin(self):
|
|
902
|
-
nonlin = {
|
|
903
|
-
"relu": nn.ReLU,
|
|
904
|
-
"leakyrelu": nn.LeakyReLU,
|
|
905
|
-
"elu": nn.ELU,
|
|
906
|
-
"selu": nn.SELU,
|
|
907
|
-
}
|
|
908
|
-
return nonlin[self.nonlin]
|
|
909
|
-
|
|
910
826
|
def get_padded_size(self, size):
|
|
911
827
|
"""
|
|
912
828
|
Returns the smallest size (H, W) of the image with actual size given
|