careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -4,74 +4,69 @@ Ladder VAE (LVAE) Model
 The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al."
 """
 
-from typing import Dict, Iterable, List, Tuple, Union
+from collections.abc import Iterable
+from typing import Dict, List, Tuple
 
-import ml_collections
 import numpy as np
 import torch
 import torch.nn as nn
 
+from careamics.config.architectures import register_model
+
+from ..activation import get_activation
 from .layers import (
     BottomUpDeterministicResBlock,
     BottomUpLayer,
     TopDownDeterministicResBlock,
     TopDownLayer,
 )
-from .likelihoods import GaussianLikelihood, NoiseModelLikelihood
-from .noise_models import get_noise_model
-from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor
+from .utils import Interpolate, ModelType, crop_img_tensor, pad_img_tensor
 
 
+@register_model("LVAE")
 class LadderVAE(nn.Module):
 
     def __init__(
         self,
-        data_mean: Union[np.ndarray, Dict[str, torch.Tensor]],
-        data_std: Union[np.ndarray, Dict[str, torch.Tensor]],
-        config: ml_collections.ConfigDict,
-        use_uncond_mode_at: Iterable[int] = [],
-        target_ch: int = 2,
+        input_shape: int,
+        output_channels: int,
+        multiscale_count: int,
+        z_dims: List[int],
+        encoder_n_filters: int,
+        decoder_n_filters: int,
+        encoder_dropout: float,
+        decoder_dropout: float,
+        nonlinearity: str,
+        predict_logvar: bool,
+        analytical_kl: bool,
     ):
         """
         Constructor.
 
         Parameters
         ----------
-        data_mean: Union[np.ndarray, Dict[str, torch.Tensor]]
-            The mean of the data used for normalization.
-        data_std: Union[np.ndarray, Dict[str, torch.Tensor]]
-            The standard deviation of the data used for normalization.
-        config: ml_collections.ConfigDict
-            The configuration object of the model.
-        use_uncond_mode_at: Iterable[int], optional
-            A sequence of indexes associated to the layers in which sampling is disabled
-            and the mode (mean value) is used instead. Default is `[]`.
-        target_ch: int, optional
-            The number of target channels (e.g., 1 for super-resolution or 2 for splitting).
-            Default is `2`.
+
         """
         super().__init__()
 
         # -------------------------------------------------------
         # Customizable attributes
-        self.image_size = config.data.image_size
-        self._multiscale_count = config.data.multiscale_lowres_count
-        self.z_dims = config.model.z_dims
-        self.encoder_n_filters = config.model.n_filters
-        self.decoder_n_filters = config.model.n_filters
-        self.encoder_dropout = config.model.dropout
-        self.decoder_dropout = config.model.dropout
-        self.nonlin = config.model.nonlin
-        self.predict_logvar = config.model.predict_logvar
-        self.enable_noise_model = config.model.enable_noise_model
-        self.noise_model_ch1_fpath = config.model.noise_model_ch1_fpath
-        self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath
-        self.analytical_kl = config.model.analytical_kl
+        self.image_size = input_shape
+        self.target_ch = output_channels
+        self._multiscale_count = multiscale_count
+        self.z_dims = z_dims
+        self.encoder_n_filters = encoder_n_filters
+        self.decoder_n_filters = decoder_n_filters
+        self.encoder_dropout = encoder_dropout
+        self.decoder_dropout = decoder_dropout
+        self.nonlin = nonlinearity
+        self.predict_logvar = predict_logvar
+        self.analytical_kl = analytical_kl
         # -------------------------------------------------------
 
         # -------------------------------------------------------
         # Model attributes -> Hardcoded
-        self.model_type = ModelType.LadderVae
+        self.model_type = ModelType.LadderVae  # TODO remove !
         self.encoder_blocks_per_layer = 1
         self.decoder_blocks_per_layer = 1
         self.bottomup_batchnorm = True
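
Example (not part of the diff): the hunk above replaces the ml_collections config object with explicit constructor arguments. A minimal sketch of instantiating the new signature directly, assuming LadderVAE is exported from careamics.models.lvae (its __init__.py is extended in this release); every value below is an illustrative assumption, not a package default:

    from careamics.models.lvae import LadderVAE  # assumed export path

    model = LadderVAE(
        input_shape=64,                # assumed training patch size
        output_channels=2,             # e.g. two channels for a splitting task
        multiscale_count=1,            # 1 disables lateral contextualization (LC)
        z_dims=[128, 128, 128, 128],   # assumed latent dims, one per hierarchy level
        encoder_n_filters=64,
        decoder_n_filters=64,
        encoder_dropout=0.1,
        decoder_dropout=0.1,
        nonlinearity="ELU",            # assumed name resolved by get_activation
        predict_logvar=None,           # the body checks `is not None` despite the bool annotation
        analytical_kl=False,
    )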
@@ -88,20 +83,13 @@ class LadderVAE(nn.Module):
         self.non_stochastic_version = False
         self.stochastic_skip = True
         self.learn_top_prior = True
-        self.res_block_type = "bacdbacd"
+        self.res_block_type = "bacdbacd"  # TODO remove !
         self.mode_pred = False
         self.logvar_lowerbound = -5
         self._var_clip_max = 20
         self._stochastic_use_naive_exponential = False
         self._enable_topdown_normalize_factor = True
 
-        # Noise model attributes -> Hardcoded
-        self.noise_model_type = "gmm"
-        self.denoise_channel = (
-            "input"  # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'}
-        )
-        self.noise_model_learnable = False
-
         # Attributes that handle LC -> Hardcoded
         self.enable_multiscale = (
             self._multiscale_count is not None and self._multiscale_count > 1
@@ -151,8 +139,6 @@ class LadderVAE(nn.Module):
         self.mixed_rec_w = 0
         self.nbr_consistency_w = 0
 
-        # Setting the loss_type
-        self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit)
         # -------------------------------------------------------
 
         # -------------------------------------------------------
@@ -167,51 +153,6 @@ class LadderVAE(nn.Module):
         # -------------------------------------------------------
 
         # -------------------------------------------------------
-        # Attributes from constructor arguments
-        self.target_ch = target_ch
-        self.use_uncond_mode_at = use_uncond_mode_at
-
-        # Data mean and std used for normalization
-        if isinstance(data_mean, np.ndarray):
-            self.data_mean = torch.Tensor(data_mean)
-            self.data_std = torch.Tensor(data_std)
-        elif isinstance(data_mean, dict):
-            for k in data_mean.keys():
-                data_mean[k] = (
-                    torch.Tensor(data_mean[k])
-                    if not isinstance(data_mean[k], dict)
-                    else data_mean[k]
-                )
-                data_std[k] = (
-                    torch.Tensor(data_std[k])
-                    if not isinstance(data_std[k], dict)
-                    else data_std[k]
-                )
-            self.data_mean = data_mean
-            self.data_std = data_std
-        else:
-            raise NotImplementedError(
-                "data_mean and data_std must be either a numpy array or a dictionary"
-            )
-
-        assert self.data_std is not None
-        assert self.data_mean is not None
-
-        # Initialize the Noise Model
-        self.likelihood_gm = self.likelihood_NM = None
-        self.noiseModel = get_noise_model(
-            enable_noise_model=self.enable_noise_model,
-            model_type=self.model_type,
-            noise_model_type=self.noise_model_type,
-            noise_model_ch1_fpath=self.noise_model_ch1_fpath,
-            noise_model_ch2_fpath=self.noise_model_ch2_fpath,
-            noise_model_learnable=self.noise_model_learnable,
-        )
-
-        if self.noiseModel is None:
-            self.likelihood_form = "gaussian"
-        else:
-            self.likelihood_form = "noise_model"
 
         # Calculate the downsampling happening in the network
         self.downsample = [1] * self.n_layers
@@ -246,7 +187,7 @@ class LadderVAE(nn.Module):
         )
 
         # Likelihood module
-        self.likelihood = self.create_likelihood_module()
+        # self.likelihood = self.create_likelihood_module()
 
         # Output layer --> Project to target_ch many channels
         logvar_ch_needed = self.predict_logvar is not None
@@ -284,11 +225,11 @@ class LadderVAE(nn.Module):
         Parameters
         ----------
         init_stride: int
-            The stride used by the intial Conv2d block.
+            The stride used by the initial Conv2d block.
         num_res_blocks: int, optional
             The number of BottomUpDeterministicResBlocks to include in the layer, default is 1.
         """
-        nonlin = self.get_nonlin()
+        nonlin = get_activation(self.nonlin)
         modules = [
             nn.Conv2d(
                 in_channels=self.color_ch,
@@ -301,7 +242,7 @@ class LadderVAE(nn.Module):
                 ),
                 stride=init_stride,
             ),
-            nonlin(),
+            nonlin,
         ]
 
         for _ in range(num_res_blocks):
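
Note on the two hunks above (not part of the diff): get_nonlin() returned an activation class that was instantiated at the call site (nonlin()), whereas get_activation, as used in this diff, appears to return a ready nn.Module instance that is dropped in directly (nonlin). A self-contained sketch of the pattern change, using a stand-in rather than the careamics implementation:

    import torch.nn as nn

    # Before: name -> class, instantiated where used.
    nonlin_cls = {"relu": nn.ReLU, "elu": nn.ELU}["elu"]
    old_block = nn.Sequential(nn.Conv2d(1, 8, 3), nonlin_cls())

    # After (per this diff): name -> instance, used directly.
    def get_activation(name: str) -> nn.Module:  # stand-in, not careamics' function
        return {"ReLU": nn.ReLU(), "ELU": nn.ELU()}[name]

    new_block = nn.Sequential(nn.Conv2d(1, 8, 3), get_activation("ELU"))

Reusing one instance across layers is safe here because these activations are stateless modules.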
@@ -337,7 +278,7 @@ class LadderVAE(nn.Module):
            not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow.
         """
         multiscale_lowres_size_factor = 1
-        nonlin = self.get_nonlin()
+        nonlin = get_activation(self.nonlin)
 
         bottom_up_layers = nn.ModuleList([])
         for i in range(self.n_layers):
@@ -409,7 +350,7 @@ class LadderVAE(nn.Module):
         ----------
         """
         top_down_layers = nn.ModuleList([])
-        nonlin = self.get_nonlin()
+        nonlin = get_activation(self.nonlin)
         # NOTE: top-down layers are created starting from the bottom-most
         for i in range(self.n_layers):
             # Check if this is the top layer
@@ -477,7 +418,7 @@ class LadderVAE(nn.Module):
             TopDownDeterministicResBlock(
                 c_in=self.decoder_n_filters,
                 c_out=self.decoder_n_filters,
-                nonlin=self.get_nonlin(),
+                nonlin=get_activation(self.nonlin),
                 batchnorm=self.topdown_batchnorm,
                 dropout=self.decoder_dropout,
                 res_block_type=self.res_block_type,
@@ -489,34 +430,9 @@ class LadderVAE(nn.Module):
             )
         return nn.Sequential(*modules)
 
-    def create_likelihood_module(self):
-        """
-        This method defines the likelihood module for the current LVAE model.
-        The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`.
-        """
-        self.likelihood_gm = GaussianLikelihood(
-            self.decoder_n_filters,
-            self.target_ch,
-            predict_logvar=self.predict_logvar,
-            logvar_lowerbound=self.logvar_lowerbound,
-            conv2d_bias=self.topdown_conv2d_bias,
-        )
-
-        self.likelihood_NM = None
-        if self.enable_noise_model:
-            self.likelihood_NM = NoiseModelLikelihood(
-                self.decoder_n_filters,
-                self.target_ch,
-                self.data_mean,
-                self.data_std,
-                self.noiseModel,
-            )
-        if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None:
-            return self.likelihood_gm
-
-        return self.likelihood_NM
-
-    def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList:
+    def _init_multires(
+        self, config=None
+    ) -> nn.ModuleList:  # TODO config: ml_collections.ConfigDict refactor
         """
         This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels
         in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer
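
With create_likelihood_module removed, the model no longer builds its own likelihood; per the file list, that responsibility moves toward careamics/config/likelihood_model.py and careamics/models/lvae/likelihoods.py. A hedged sketch (not part of the diff) of constructing the Gaussian head externally, reusing only the argument order visible in the removed code and making no claim about the actual 0.0.4 signature:

    from careamics.models.lvae.likelihoods import GaussianLikelihood

    likelihood = GaussianLikelihood(
        64,                    # decoder_n_filters of the paired model (assumed)
        2,                     # target channels (assumed)
        predict_logvar=None,
        logvar_lowerbound=-5,  # matches the hardcoded attribute in __init__
        conv2d_bias=True,      # assumed default
    )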
@@ -531,7 +447,7 @@ class LadderVAE(nn.Module):
         In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs.
         """
         stride = 1 if self.no_initial_downscaling else 2
-        nonlin = self.get_nonlin()
+        nonlin = get_activation(self.nonlin)
         if self._multiscale_count is None:
             self._multiscale_count = 1
 
@@ -556,7 +472,7 @@ class LadderVAE(nn.Module):
                    padding=2,
                    stride=stride,
                ),
-                nonlin(),
+                nonlin,
                BottomUpDeterministicResBlock(
                    c_in=self.encoder_n_filters,
                    c_out=self.encoder_n_filters,
@@ -596,7 +512,7 @@ class LadderVAE(nn.Module):
         bottom_up_layers: nn.ModuleList,
     ) -> List[torch.Tensor]:
         """
-        This method defines the forward pass throught the LVAE Encoder, the so-called
+        This method defines the forward pass through the LVAE Encoder, the so-called
         Bottom-Up pass.
 
         Parameters
@@ -642,7 +558,7 @@ class LadderVAE(nn.Module):
         final_top_down_layer: nn.Sequential = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
-        This method defines the forward pass throught the LVAE Decoder, the so-called
+        This method defines the forward pass through the LVAE Decoder, the so-called
         Top-Down pass.
 
         Parameters
@@ -664,10 +580,10 @@ class LadderVAE(nn.Module):
            place in this case).
         top_down_layers: nn.ModuleList, optional
            A list of top-down layers to use in the top-down pass. If `None`, the method uses the
-            default layers defined in the contructor.
+            default layers defined in the constructor.
         final_top_down_layer: nn.Sequential, optional
            The last top-down layer of the top-down pass. If `None`, the method uses the default
-            layers defined in the contructor.
+            layers defined in the constructor.
         """
         if top_down_layers is None:
             top_down_layers = self.top_down_layers
@@ -742,7 +658,6 @@ class LadderVAE(nn.Module):
             # Whether the current layer should be sampled from the mode
             use_mode = i in mode_layers
             constant_out = i in constant_layers
-            use_uncond_mode = i in self.use_uncond_mode_at
 
             # Input for skip connection
             skip_input = out  # TODO or n? or both?
@@ -758,7 +673,6 @@ class LadderVAE(nn.Module):
                 force_constant_output=constant_out,
                 forced_latent=forced_latent[i],
                 mode_pred=self.mode_pred,
-                use_uncond_mode=use_uncond_mode,
                 var_clip_max=self._var_clip_max,
             )
 
@@ -881,11 +795,18 @@ class LadderVAE(nn.Module):
 
         # return samples
 
-    # def reset_for_different_output_size(self, output_size):
-    #     for i in range(self.n_layers):
-    #         sz = output_size // 2**(1 + i)
-    #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-    #         self.top_down_layers[i].latent_shape = (output_size, output_size)
+    def reset_for_different_output_size(self, output_size: int) -> None:
+        """Reset shape of output and latent tensors for different output size.
+
+        Used during evaluation to reset expected shapes of tensors when
+        input/output shape changes.
+        For instance, it is needed when the model was trained on, say, 64x64 sized
+        patches, but prediction is done on 128x128 patches.
+        """
+        for i in range(self.n_layers):
+            sz = output_size // 2 ** (1 + i)
+            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
+            self.top_down_layers[i].latent_shape = (output_size, output_size)
 
     def pad_input(self, x):
         """
@@ -898,15 +819,6 @@ class LadderVAE(nn.Module):
         return x
 
     ### SET OF GETTERS
-    def get_nonlin(self):
-        nonlin = {
-            "relu": nn.ReLU,
-            "leakyrelu": nn.LeakyReLU,
-            "elu": nn.ELU,
-            "selu": nn.SELU,
-        }
-        return nonlin[self.nonlin]
-
     def get_padded_size(self, size):
         """
         Returns the smallest size (H, W) of the image with actual size given