careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_prediction_loop.py
@@ -0,0 +1,116 @@
+ from typing import Optional
+
+ import pytorch_lightning as L
+ from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
+ from pytorch_lightning.loops.utilities import _no_grad_context
+ from pytorch_lightning.trainer import call
+ from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
+
+ from careamics.prediction import stitch_prediction
+
+
+ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
+     """
+     CAREamics prediction loop.
+
+     This class extends the PyTorch Lightning `_PredictionLoop` class to include
+     the stitching of the tiles into a single prediction result.
+     """
+
+     def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
+         """
+         Calls `on_predict_epoch_end` hook.
+
+         Adapted from the parent method.
+
+         Returns
+         -------
+         the results for all dataloaders
+         """
+         trainer = self.trainer
+         call._call_callback_hooks(trainer, "on_predict_epoch_end")
+         call._call_lightning_module_hook(trainer, "on_predict_epoch_end")
+
+         if self.return_predictions:
+             ########################################################
+             ################ CAREamics specific code ###############
+             if len(self.predicted_array) == 1:
+                 # TODO does this make sense to here? (force numpy array)
+                 return self.predicted_array[0].numpy()
+             else:
+                 # TODO revisit logic
+                 return [element.numpy() for element in self.predicted_array]
+             ########################################################
+         return None
+
+     @_no_grad_context
+     def run(self) -> Optional[_PREDICT_OUTPUT]:
+         """
+         Runs the prediction loop.
+
+         Adapted from the parent method in order to stitch the predictions.
+
+         Returns
+         -------
+         Optional[_PREDICT_OUTPUT]
+             Prediction output
+         """
+         self.setup_data()
+         if self.skip:
+             return None
+         self.reset()
+         self.on_run_start()
+         data_fetcher = self._data_fetcher
+         assert data_fetcher is not None
+
+         self.predicted_array = []
+         self.tiles = []
+         self.stitching_data = []
+
+         while True:
+             try:
+                 if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
+                     dataloader_iter = next(data_fetcher)
+                     # hook's batch_idx and dataloader_idx arguments correctness cannot
+                     # be guaranteed in this setting
+                     batch = data_fetcher._batch
+                     batch_idx = data_fetcher._batch_idx
+                     dataloader_idx = data_fetcher._dataloader_idx
+                 else:
+                     dataloader_iter = None
+                     batch, batch_idx, dataloader_idx = next(data_fetcher)
+                 self.batch_progress.is_last_batch = data_fetcher.done
+
+                 # run step hooks
+                 self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
+
+                 ########################################################
+                 ################ CAREamics specific code ###############
+                 is_tiled = len(self.predictions[batch_idx]) == 2
+                 if is_tiled:
+                     # extract the last tile flag and the coordinates (crop and stitch)
+                     last_tile, *stitch_data = self.predictions[batch_idx][1]
+
+                     # append the tile and the coordinates to the lists
+                     self.tiles.append(self.predictions[batch_idx][0])
+                     self.stitching_data.append(stitch_data)
+
+                     # if last tile, stitch the tiles and add array to the prediction
+                     if any(last_tile):
+                         predicted_batches = stitch_prediction(
+                             self.tiles, self.stitching_data
+                         )
+                         self.predicted_array.append(predicted_batches)
+                         self.tiles.clear()
+                         self.stitching_data.clear()
+                 else:
+                     # simply add the prediction to the list
+                     self.predicted_array.append(self.predictions[batch_idx])
+                 ########################################################
+             except StopIteration:
+                 break
+             finally:
+                 self._restarting = False
+         return self.on_run_end()
+
+         # TODO predictions aren't stacked, list returned
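Note: the sketch below shows one plausible way to attach this loop to a Lightning Trainer. It is an illustration inferred from the PyTorch Lightning 2.x API (where `Trainer.predict_loop` is a plain attribute), not code shipped in this release; `model` and `predict_datamodule` are placeholder names.

import pytorch_lightning as L
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop

# Build a trainer, then swap the default _PredictionLoop for the CAREamics loop
# so that tiled predictions are stitched back into full images.
trainer = L.Trainer(accelerator="cpu", logger=False)
trainer.predict_loop = CAREamicsPredictionLoop(trainer)

# predictions = trainer.predict(model, datamodule=predict_datamodule)
# -> a single numpy array, or a list of numpy arrays when several images are predicted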
careamics/losses/__init__.py
@@ -1,4 +1,7 @@
  """Losses module."""


- from .loss_factory import create_loss_function as create_loss_function
+ from .loss_factory import loss_factory
+
+ # from .noise_model_factory import noise_model_factory as noise_model_factory
+ # from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
careamics/losses/loss_factory.py
@@ -3,22 +3,21 @@ Loss factory module.

  This module contains a factory function for creating loss functions.
  """
- from typing import Callable
+ from typing import Callable, Union

- from careamics.config import Configuration
- from careamics.config.algorithm import Loss
+ from ..config.support import SupportedLoss
+ from .losses import mae_loss, mse_loss, n2v_loss

- from .losses import n2v_loss

-
- def create_loss_function(config: Configuration) -> Callable:
-     """
-     Create loss function based on Configuration.
+ # TODO add tests
+ # TODO add custom?
+ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
+     """Return loss function.

      Parameters
      ----------
-     config : Configuration
-         Configuration.
+     loss: SupportedLoss
+         Requested loss.

      Returns
      -------
@@ -30,9 +29,20 @@ def create_loss_function(config: Configuration) -> Callable:
      NotImplementedError
          If the loss is unknown.
      """
-     loss_type = config.algorithm.loss
-
-     if loss_type == Loss.N2V:
+     if loss == SupportedLoss.N2V:
          return n2v_loss
+
+     # elif loss_type == SupportedLoss.PN2V:
+     #     return pn2v_loss
+
+     elif loss == SupportedLoss.MAE:
+         return mae_loss
+
+     elif loss == SupportedLoss.MSE:
+         return mse_loss
+
+     # elif loss_type == SupportedLoss.DICE:
+     #     return dice_loss
+
      else:
-         raise NotImplementedError(f"Loss {loss_type} is not yet supported.")
+         raise NotImplementedError(f"Loss {loss} is not yet supported.")
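For illustration, a minimal sketch of the new factory in use; the enum members are the ones referenced in the hunk above, and the import paths follow the files listed in this diff.

from careamics.config.support import SupportedLoss
from careamics.losses import loss_factory

n2v = loss_factory(SupportedLoss.N2V)  # -> n2v_loss
mae = loss_factory(SupportedLoss.MAE)  # -> mae_loss (N2N-style L1 loss)
mse = loss_factory(SupportedLoss.MSE)  # -> mse_loss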
careamics/losses/losses.py
@@ -3,14 +3,34 @@ Loss submodule.

  This submodule contains the various losses used in CAREamics.
  """
+
  import torch

+ # TODO if we are only using the DiceLoss, can we just implement it?
+ # from segmentation_models_pytorch.losses import DiceLoss
+ from torch.nn import L1Loss, MSELoss
+
+
+ def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+     """
+     Mean squared error loss.
+
+     Returns
+     -------
+     torch.Tensor
+         Loss value.
+     """
+     loss = MSELoss()
+     return loss(samples, labels)
+

  def n2v_loss(
-     samples: torch.Tensor, labels: torch.Tensor, masks: torch.Tensor, device: str
+     manipulated_patches: torch.Tensor,
+     original_patches: torch.Tensor,
+     masks: torch.Tensor,
  ) -> torch.Tensor:
      """
-     N2V Loss function (see Eq.7 in Krull et al).
+     N2V Loss function described in A Krull et al 2018.

      Parameters
      ----------
@@ -20,15 +40,55 @@ def n2v_loss(
          Noisy patches.
      masks : torch.Tensor
          Array containing masked pixel locations.
-     device : str
-         Device to use.

      Returns
      -------
      torch.Tensor
          Loss value.
      """
-     errors = (labels - samples) ** 2
+     errors = (original_patches - manipulated_patches) ** 2
      # Average over pixels and batch
      loss = torch.sum(errors * masks) / torch.sum(masks)
      return loss
+
+
+ def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+     """
+     N2N Loss function described in to J Lehtinen et al 2018.
+
+     Parameters
+     ----------
+     samples : torch.Tensor
+         Raw patches.
+     labels : torch.Tensor
+         Different subset of noisy patches.
+
+     Returns
+     -------
+     torch.Tensor
+         Loss value.
+     """
+     loss = L1Loss()
+     return loss(samples, labels)
+
+
+ # def pn2v_loss(
+ #     samples: torch.Tensor,
+ #     labels: torch.Tensor,
+ #     masks: torch.Tensor,
+ #     noise_model: HistogramNoiseModel,
+ # ) -> torch.Tensor:
+ #     """Probabilistic N2V loss function described in A Krull et al., CVF (2019)."""
+ #     likelihoods = noise_model.likelihood(labels, samples)
+ #     likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
+
+ #     # Average over pixels and batch
+ #     loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
+ #     return loss
+
+
+ # def dice_loss(
+ #     samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
+ # ) -> torch.Tensor:
+ #     """Dice loss function."""
+ #     return DiceLoss(mode=mode)(samples, labels.long())
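A tiny numerical check of the masked N2V loss added above (toy 2x2 tensors; real inputs are batched patches): only pixels flagged in `masks` contribute to the average.

import torch

from careamics.losses.losses import n2v_loss

manipulated = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # network output on the masked input
original = torch.tensor([[1.5, 2.0], [3.0, 5.0]])     # original noisy patch
masks = torch.tensor([[1.0, 0.0], [0.0, 1.0]])        # masked pixel locations

# ((1.5 - 1.0) ** 2 + (5.0 - 4.0) ** 2) / 2 masked pixels = 0.625
print(n2v_loss(manipulated, original, masks))  # tensor(0.6250)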
careamics/losses/noise_model_factory.py
@@ -0,0 +1,40 @@
+ from typing import Type, Union
+
+ from ..config.noise_models import NoiseModel, NoiseModelType
+ from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
+
+
+ def noise_model_factory(
+     noise_config: NoiseModel,
+ ) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]:
+     """Create loss model based on Configuration.
+
+     Parameters
+     ----------
+     config : Configuration
+         Configuration.
+
+     Returns
+     -------
+     Noise model
+
+     Raises
+     ------
+     NotImplementedError
+         If the noise model is unknown.
+     """
+     noise_model_type = noise_config.model_type if noise_config else None
+
+     if noise_model_type == NoiseModelType.HIST:
+         return HistogramNoiseModel
+
+     elif noise_model_type == NoiseModelType.GMM:
+         return GaussianMixtureNoiseModel
+
+     elif noise_model_type is None:
+         return None
+
+     else:
+         raise NotImplementedError(
+             f"Noise model {noise_model_type} is not yet supported."
+         )
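As written, the factory returns a noise-model class rather than an instance, and short-circuits to None when no configuration is passed; a minimal sketch based only on the code above:

from careamics.losses.noise_model_factory import noise_model_factory

# No noise model configured -> None.
assert noise_model_factory(None) is None

# With a config whose model_type is NoiseModelType.HIST or NoiseModelType.GMM,
# the corresponding class (HistogramNoiseModel or GaussianMixtureNoiseModel) is
# returned and must still be instantiated by the caller.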