careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0

careamics/models/lvae/lvae.py

@@ -4,74 +4,73 @@ Ladder VAE (LVAE) Model
  The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
  """

- from typing import Dict, Iterable, List, Tuple, Union
+ from collections.abc import Iterable
+ from typing import Dict, List, Tuple

- import ml_collections
  import numpy as np
  import torch
  import torch.nn as nn

+ from careamics.config.architectures import register_model
+
+ from ..activation import get_activation
  from .layers import (
  BottomUpDeterministicResBlock,
  BottomUpLayer,
  TopDownDeterministicResBlock,
  TopDownLayer,
  )
- from .likelihoods import GaussianLikelihood, NoiseModelLikelihood
- from .noise_models import get_noise_model
- from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor
+ from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor


+ @register_model("LVAE")
  class LadderVAE(nn.Module):

  def __init__(
  self,
- data_mean: Union[np.ndarray, Dict[str, torch.Tensor]],
- data_std: Union[np.ndarray, Dict[str, torch.Tensor]],
- config: ml_collections.ConfigDict,
- use_uncond_mode_at: Iterable[int] = [],
- target_ch: int = 2,
+ input_shape: int,
+ output_channels: int,
+ multiscale_count: int,
+ z_dims: List[int],
+ encoder_n_filters: int,
+ decoder_n_filters: int,
+ encoder_dropout: float,
+ decoder_dropout: float,
+ nonlinearity: str,
+ predict_logvar: bool,
+ enable_noise_model: bool,
+ analytical_kl: bool,
  ):
  """
  Constructor.

  Parameters
  ----------
- data_mean: Union[np.ndarray, Dict[str, torch.Tensor]]
- The mean of the data used for normalization.
- data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
- The standard deviation of the data used for normalization.
- config: ml_collections.ConfigDict
- The configuration object of the model.
- use_uncond_mode_at: Iterable[int], optional
- A sequence of indexes associated to the layers in which sampling is disabled
- and the mode (mean value) is used instead. Default is `[]`.
- target_ch: int, optional
- The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
- Default is `2`.
+
  """
  super().__init__()

  # -------------------------------------------------------
  # Customizable attributes
- self.image_size = config.data.image_size
- self._multiscale_count = config.data.multiscale_lowres_count
- self.z_dims = config.model.z_dims
- self.encoder_n_filters = config.model.n_filters
- self.decoder_n_filters = config.model.n_filters
- self.encoder_dropout = config.model.dropout
- self.decoder_dropout = config.model.dropout
- self.nonlin = config.model.nonlin
- self.predict_logvar = config.model.predict_logvar
- self.enable_noise_model = config.model.enable_noise_model
- self.noise_model_ch1_fpath = config.model.noise_model_ch1_fpath
- self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath
- self.analytical_kl = config.model.analytical_kl
+ self.image_size = input_shape
+ self.target_ch = output_channels
+ self._multiscale_count = multiscale_count
+ self.z_dims = z_dims
+ self.encoder_n_filters = encoder_n_filters
+ self.decoder_n_filters = decoder_n_filters
+ self.encoder_dropout = encoder_dropout
+ self.decoder_dropout = decoder_dropout
+ self.nonlin = nonlinearity
+ self.predict_logvar = predict_logvar
+ self.enable_noise_model = enable_noise_model
+
+ self.analytical_kl = analytical_kl
  # -------------------------------------------------------

  # -------------------------------------------------------
  # Model attributes -> Hardcoded
- self.model_type = ModelType.LadderVae
+ self.model_type = ModelType.LadderVae # TODO remove !
+ self.model_type = ModelType.LadderVae # TODO remove !
  self.encoder_blocks_per_layer = 1
  self.decoder_blocks_per_layer = 1
  self.bottomup_batchnorm = True
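
Note: the hunk above replaces the ml_collections.ConfigDict-driven constructor of LadderVAE with explicit keyword arguments. Below is a minimal instantiation sketch based only on the signature shown in this diff; the argument values are illustrative placeholders rather than package defaults, and the import path assumes the module being diffed here (careamics/models/lvae/lvae.py).

```python
from careamics.models.lvae.lvae import LadderVAE

# All values below are illustrative placeholders, not careamics defaults.
model = LadderVAE(
    input_shape=64,              # patch size seen by the network
    output_channels=2,           # e.g. 2 output channels for a splitting task
    multiscale_count=1,          # 1 means no lateral low-resolution inputs
    z_dims=[128, 128, 128, 128],
    encoder_n_filters=64,
    decoder_n_filters=64,
    encoder_dropout=0.1,
    decoder_dropout=0.1,
    nonlinearity="ELU",          # resolved through get_activation (see later hunks)
    predict_logvar=False,        # annotated as bool here; later code checks `is not None`
    enable_noise_model=False,
    analytical_kl=False,
)
```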
@@ -88,7 +87,7 @@ class LadderVAE(nn.Module):
  self.non_stochastic_version = False
  self.stochastic_skip = True
  self.learn_top_prior = True
- self.res_block_type = "bacdbacd"
+ self.res_block_type = "bacdbacd" # TODO remove !
  self.mode_pred = False
  self.logvar_lowerbound = -5
  self._var_clip_max = 20
@@ -151,8 +150,6 @@ class LadderVAE(nn.Module):
  self.mixed_rec_w = 0
  self.nbr_consistency_w = 0

- # Setting the loss_type
- self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
  # -------------------------------------------------------

  # -------------------------------------------------------
@@ -167,51 +164,6 @@ class LadderVAE(nn.Module):
  # -------------------------------------------------------

  # -------------------------------------------------------
- # Attributes from constructor arguments
- self.target_ch = target_ch
- self.use_uncond_mode_at = use_uncond_mode_at
-
- # Data mean and std used for normalization
- if isinstance(data_mean, np.ndarray):
- self.data_mean = torch.Tensor(data_mean)
- self.data_std = torch.Tensor(data_std)
- elif isinstance(data_mean, dict):
- for k in data_mean.keys():
- data_mean[k] = (
- torch.Tensor(data_mean[k])
- if not isinstance(data_mean[k], dict)
- else data_mean[k]
- )
- data_std[k] = (
- torch.Tensor(data_std[k])
- if not isinstance(data_std[k], dict)
- else data_std[k]
- )
- self.data_mean = data_mean
- self.data_std = data_std
- else:
- raise NotImplementedError(
- "data_mean and data_std must be either a numpy array or a dictionary"
- )
-
- assert self.data_std is not None
- assert self.data_mean is not None
-
- # Initialize the Noise Model
- self.likelihood_gm = self.likelihood_NM = None
- self.noiseModel = get_noise_model(
- enable_noise_model=self.enable_noise_model,
- model_type=self.model_type,
- noise_model_type=self.noise_model_type,
- noise_model_ch1_fpath=self.noise_model_ch1_fpath,
- noise_model_ch2_fpath=self.noise_model_ch2_fpath,
- noise_model_learnable=self.noise_model_learnable,
- )
-
- if self.noiseModel is None:
- self.likelihood_form = "gaussian"
- else:
- self.likelihood_form = "noise_model"

  # Calculate the downsampling happening in the network
  self.downsample = [1] * self.n_layers
@@ -246,7 +198,7 @@ class LadderVAE(nn.Module):
  )

  # Likelihood module
- self.likelihood = self.create_likelihood_module()
+ # self.likelihood = self.create_likelihood_module()

  # Output layer --> Project to target_ch many channels
  logvar_ch_needed = self.predict_logvar is not None
@@ -284,11 +236,11 @@ class LadderVAE(nn.Module):
  Parameters
  ----------
  init_stride: int
- The stride used by the intial Conv2d block.
+ The stride used by the initial Conv2d block.
  num_res_blocks: int, optional
  The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
  """
- nonlin = self.get_nonlin()
+ nonlin = get_activation(self.nonlin)
  modules = [
  nn.Conv2d(
  in_channels=self.color_ch,
@@ -301,7 +253,7 @@ class LadderVAE(nn.Module):
  ),
  stride=init_stride,
  ),
- nonlin(),
+ nonlin,
  ]

  for _ in range(num_res_blocks):
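
Note: the `nonlin(),` to `nonlin,` change above (and the matching edits in later hunks) follows from get_nonlin being removed in favour of get_activation: the old helper returned an activation class that call sites instantiated inline, while the new call sites use the returned object directly, which implies get_activation hands back an nn.Module instance. A self-contained sketch of the two patterns in plain PyTorch; get_activation_sketch is a hypothetical stand-in, not the careamics helper.

```python
import torch
import torch.nn as nn

# 0.0.2 pattern: a mapping of activation *classes*, instantiated at the call site.
nonlin_cls = {"relu": nn.ReLU, "elu": nn.ELU, "selu": nn.SELU}["elu"]
old_block = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nonlin_cls())

# 0.0.3 pattern (as the diff suggests): the helper returns an *instance*,
# so the layer list passes `nonlin` without trailing parentheses.
def get_activation_sketch(name: str) -> nn.Module:  # hypothetical stand-in
    return {"ReLU": nn.ReLU(), "ELU": nn.ELU(), "SELU": nn.SELU()}[name]

new_block = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), get_activation_sketch("ELU"))

x = torch.randn(1, 1, 16, 16)
assert old_block(x).shape == new_block(x).shape == (1, 8, 16, 16)
```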
@@ -337,7 +289,7 @@ class LadderVAE(nn.Module):
  not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
  """
  multiscale_lowres_size_factor = 1
- nonlin = self.get_nonlin()
+ nonlin = get_activation(self.nonlin)

  bottom_up_layers = nn.ModuleList([])
  for i in range(self.n_layers):
@@ -409,7 +361,7 @@ class LadderVAE(nn.Module):
  ----------
  """
  top_down_layers = nn.ModuleList([])
- nonlin = self.get_nonlin()
+ nonlin = get_activation(self.nonlin)
  # NOTE: top-down layers are created starting from the bottom-most
  for i in range(self.n_layers):
  # Check if this is the top layer
@@ -477,7 +429,7 @@ class LadderVAE(nn.Module):
  TopDownDeterministicResBlock(
  c_in=self.decoder_n_filters,
  c_out=self.decoder_n_filters,
- nonlin=self.get_nonlin(),
+ nonlin=get_activation(self.nonlin),
  batchnorm=self.topdown_batchnorm,
  dropout=self.decoder_dropout,
  res_block_type=self.res_block_type,
@@ -489,34 +441,9 @@ class LadderVAE(nn.Module):
  )
  return nn.Sequential(*modules)

- def create_likelihood_module(self):
- """
- This method defines the likelihood module for the current LVAE model.
- The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
- """
- self.likelihood_gm = GaussianLikelihood(
- self.decoder_n_filters,
- self.target_ch,
- predict_logvar=self.predict_logvar,
- logvar_lowerbound=self.logvar_lowerbound,
- conv2d_bias=self.topdown_conv2d_bias,
- )
-
- self.likelihood_NM = None
- if self.enable_noise_model:
- self.likelihood_NM = NoiseModelLikelihood(
- self.decoder_n_filters,
- self.target_ch,
- self.data_mean,
- self.data_std,
- self.noiseModel,
- )
- if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
- return self.likelihood_gm
-
- return self.likelihood_NM
-
- def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList:
+ def _init_multires(
+ self, config=None
+ ) -> nn.ModuleList: # TODO config: ml_collections.ConfigDict refactor
  """
  This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
  in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
@@ -531,7 +458,7 @@ class LadderVAE(nn.Module):
  In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
  """
  stride = 1 if self.no_initial_downscaling else 2
- nonlin = self.get_nonlin()
+ nonlin = get_activation(self.nonlin)
  if self._multiscale_count is None:
  self._multiscale_count = 1

@@ -556,7 +483,7 @@ class LadderVAE(nn.Module):
  padding=2,
  stride=stride,
  ),
- nonlin(),
+ nonlin,
  BottomUpDeterministicResBlock(
  c_in=self.encoder_n_filters,
  c_out=self.encoder_n_filters,
@@ -596,7 +523,7 @@ class LadderVAE(nn.Module):
  bottom_up_layers: nn.ModuleList,
  ) -> List[torch.Tensor]:
  """
- This method defines the forward pass throught the LVAE Encoder, the so-called
+ This method defines the forward pass through the LVAE Encoder, the so-called
  Bottom-Up pass.

  Parameters
@@ -642,7 +569,7 @@ class LadderVAE(nn.Module):
  final_top_down_layer: nn.Sequential = None,
  ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  """
- This method defines the forward pass throught the LVAE Decoder, the so-called
+ This method defines the forward pass through the LVAE Decoder, the so-called
  Top-Down pass.

  Parameters
@@ -664,10 +591,10 @@ class LadderVAE(nn.Module):
  place in this case).
  top_down_layers: nn.ModuleList, optional
  A list of top-down layers to use in the top-down pass. If `None`, the method uses the
- default layers defined in the contructor.
+ default layers defined in the constructor.
  final_top_down_layer: nn.Sequential, optional
  The last top-down layer of the top-down pass. If `None`, the method uses the default
- layers defined in the contructor.
+ layers defined in the constructor.
  """
  if top_down_layers is None:
  top_down_layers = self.top_down_layers
@@ -742,7 +669,6 @@ class LadderVAE(nn.Module):
  # Whether the current layer should be sampled from the mode
  use_mode = i in mode_layers
  constant_out = i in constant_layers
- use_uncond_mode = i in self.use_uncond_mode_at

  # Input for skip connection
  skip_input = out # TODO or n? or both?
@@ -758,7 +684,6 @@ class LadderVAE(nn.Module):
  force_constant_output=constant_out,
  forced_latent=forced_latent[i],
  mode_pred=self.mode_pred,
- use_uncond_mode=use_uncond_mode,
  var_clip_max=self._var_clip_max,
  )

@@ -898,15 +823,6 @@ class LadderVAE(nn.Module):
  return x

  ### SET OF GETTERS
- def get_nonlin(self):
- nonlin = {
- "relu": nn.ReLU,
- "leakyrelu": nn.LeakyReLU,
- "elu": nn.ELU,
- "selu": nn.SELU,
- }
- return nonlin[self.nonlin]
-
  def get_padded_size(self, size):
  """
  Returns the smallest size (H, W) of the image with actual size given