careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/losses/fcn/losses.py
@@ -0,0 +1,98 @@
+ """
+ Loss submodule.
+
+ This submodule contains the various losses used in CAREamics.
+ """
+
+ import torch
+ from torch.nn import L1Loss, MSELoss
+
+
+ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+     """
+     Mean squared error loss.
+
+     Parameters
+     ----------
+     source : torch.Tensor
+         Source patches.
+     target : torch.Tensor
+         Target patches.
+
+     Returns
+     -------
+     torch.Tensor
+         Loss value.
+     """
+     loss = MSELoss()
+     return loss(source, target)
+
+
+ def n2v_loss(
+     manipulated_patches: torch.Tensor,
+     original_patches: torch.Tensor,
+     masks: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     N2V loss function described in A. Krull et al. 2018.
+
+     Parameters
+     ----------
+     manipulated_patches : torch.Tensor
+         Patches with manipulated pixels.
+     original_patches : torch.Tensor
+         Noisy patches.
+     masks : torch.Tensor
+         Array containing masked pixel locations.
+
+     Returns
+     -------
+     torch.Tensor
+         Loss value.
+     """
+     errors = (original_patches - manipulated_patches) ** 2
+     # average over masked pixels and batch
+     loss = torch.sum(errors * masks) / torch.sum(masks)
+     return loss  # TODO change output to dict?
+
+
+ def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+     """
+     N2N loss function described in J. Lehtinen et al. 2018.
+
+     Parameters
+     ----------
+     samples : torch.Tensor
+         Raw patches.
+     labels : torch.Tensor
+         Different subset of noisy patches.
+
+     Returns
+     -------
+     torch.Tensor
+         Loss value.
+     """
+     loss = L1Loss()
+     return loss(samples, labels)
+
+
+ # def pn2v_loss(
+ #     samples: torch.Tensor,
+ #     labels: torch.Tensor,
+ #     masks: torch.Tensor,
+ #     noise_model: HistogramNoiseModel,
+ # ) -> torch.Tensor:
+ #     """Probabilistic N2V loss function described in A. Krull et al., CVF (2019)."""
+ #     likelihoods = noise_model.likelihood(labels, samples)
+ #     likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
+
+ #     # average over pixels and batch
+ #     loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
+ #     return loss
+
+
+ # def dice_loss(
+ #     samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
+ # ) -> torch.Tensor:
+ #     """Dice loss function."""
+ #     return DiceLoss(mode=mode)(samples, labels.long())
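
For orientation, a minimal sketch of calling the three losses above on dummy data; the patch shapes and the mask density are illustrative assumptions, not values prescribed by the package.

import torch

from careamics.losses.fcn.losses import mae_loss, mse_loss, n2v_loss

# dummy batch: 8 single-channel 64x64 patches
predictions = torch.rand(8, 1, 64, 64)
targets = torch.rand(8, 1, 64, 64)

print(mse_loss(predictions, targets))  # scalar tensor
print(mae_loss(predictions, targets))  # scalar tensor

# n2v_loss compares its first argument to the original noisy patches,
# but only at the masked pixel locations; here roughly 1% of pixels are masked
masks = (torch.rand(8, 1, 64, 64) > 0.99).float()
print(n2v_loss(predictions, targets, masks))  # scalar tensor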
careamics/losses/loss_factory.py
@@ -0,0 +1,155 @@
+ """
+ Loss factory module.
+
+ This module contains a factory function for creating loss functions.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+
+ from torch import Tensor as tensor
+
+ from ..config.support import SupportedLoss
+ from .fcn.losses import mae_loss, mse_loss, n2v_loss
+ from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+
+ if TYPE_CHECKING:
+     from careamics.models.lvae.likelihoods import (
+         GaussianLikelihood,
+         NoiseModelLikelihood,
+     )
+     from careamics.models.lvae.noise_models import (
+         GaussianMixtureNoiseModel,
+         MultiChannelNoiseModel,
+     )
+
+     NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ @dataclass
+ class FCNLossParameters:
+     """Dataclass for FCN loss."""
+
+     # TODO check
+     prediction: tensor
+     targets: tensor
+     mask: tensor
+     current_epoch: int
+     loss_weight: float
+
+
+ @dataclass  # TODO why not pydantic?
+ class LVAELossParameters:
+     """Dataclass for LVAE loss."""
+
+     # TODO: refactor into more modular blocks (otherwise it gets messy very easily)
+     # e.g., weights, kl_params, ...
+
+     noise_model_likelihood: Optional[NoiseModelLikelihood] = None
+     """Noise model likelihood instance."""
+     gaussian_likelihood: Optional[GaussianLikelihood] = None
+     """Gaussian likelihood instance."""
+     current_epoch: int = 0
+     """Current epoch in the training loop."""
+     reconstruction_weight: float = 1.0
+     """Weight for the reconstruction loss in the total net loss
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     musplit_weight: float = 0.0
+     """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+     denoisplit_weight: float = 1.0
+     """Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
+     kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
+     """Type of KL divergence used as KL loss."""
+     kl_weight: float = 1.0
+     """Weight for the KL loss in the total net loss
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     kl_annealing: bool = False
+     """Whether to apply KL loss annealing."""
+     kl_start: int = -1
+     """Epoch at which KL loss annealing starts."""
+     kl_annealtime: int = 10
+     """Number of epochs for which KL loss annealing is applied."""
+     non_stochastic: bool = False
+     """Whether to skip sampling the latents and computing the KL loss."""
+
+
+ # TODO: really needed?
+ # As it is now, it is difficult to use; we need a way to specify the
+ # loss parameters in a more user-friendly way.
+ def loss_parameters_factory(
+     type: SupportedLoss,
+ ) -> Union[FCNLossParameters, LVAELossParameters]:
+     """Return loss parameters.
+
+     Parameters
+     ----------
+     type : SupportedLoss
+         Requested loss.
+
+     Returns
+     -------
+     Union[FCNLossParameters, LVAELossParameters]
+         Loss parameters.
+
+     Raises
+     ------
+     NotImplementedError
+         If the loss is unknown.
+     """
+     if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
+         return FCNLossParameters
+
+     elif type in [
+         SupportedLoss.MUSPLIT,
+         SupportedLoss.DENOISPLIT,
+         SupportedLoss.DENOISPLIT_MUSPLIT,
+     ]:
+         return LVAELossParameters  # it returns the class, not an instance
+
+     else:
+         raise NotImplementedError(f"Loss {type} is not yet supported.")
+
+
+ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
+     """Return loss function.
+
+     Parameters
+     ----------
+     loss : Union[SupportedLoss, str]
+         Requested loss.
+
+     Returns
+     -------
+     Callable
+         Loss function.
+
+     Raises
+     ------
+     NotImplementedError
+         If the loss is unknown.
+     """
+     if loss == SupportedLoss.N2V:
+         return n2v_loss
+
+     # elif loss_type == SupportedLoss.PN2V:
+     #     return pn2v_loss
+
+     elif loss == SupportedLoss.MAE:
+         return mae_loss
+
+     elif loss == SupportedLoss.MSE:
+         return mse_loss
+
+     elif loss == SupportedLoss.MUSPLIT:
+         return musplit_loss
+
+     elif loss == SupportedLoss.DENOISPLIT:
+         return denoisplit_loss
+
+     elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
+         return denoisplit_musplit_loss
+
+     else:
+         raise NotImplementedError(f"Loss {loss} is not yet supported.")
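
A short usage sketch of the two factories. Passing a SupportedLoss member is the unambiguous form; plain strings should also compare equal if SupportedLoss is a string-backed enum (the package ships a BaseEnum in careamics/utils/base_enum.py), but the exact string values are not shown in this diff and are an assumption here.

from careamics.config.support import SupportedLoss
from careamics.losses.loss_factory import loss_factory, loss_parameters_factory

loss_fn = loss_factory(SupportedLoss.N2V)  # the n2v_loss function itself
params_cls = loss_parameters_factory(SupportedLoss.N2V)  # FCNLossParameters, a class

Note that loss_parameters_factory returns the dataclass itself, not an instance; as the TODOs above acknowledge, callers are expected to instantiate it with the actual tensors and weights.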
careamics/losses/lvae/__init__.py
@@ -0,0 +1 @@
+ """LVAE losses."""
careamics/losses/lvae/loss_utils.py
@@ -0,0 +1,83 @@
+ from typing import Optional
+
+ import torch
+
+
+ def free_bits_kl(
+     kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
+ ) -> torch.Tensor:
+     """Compute the free-bits version of the KL divergence.
+
+     This function ensures that the KL doesn't go to zero for any latent dimension.
+     Hence, it encourages a more efficient use of the latent variables, leading to
+     better representation learning.
+
+     NOTE:
+     Takes in the KL with shape (batch size, layers) and returns the KL with
+     free bits (for optimization) with shape (layers,), which is the average
+     free-bits KL per layer in the current batch.
+     If batch_average is False (default), the free bits are applied per layer and
+     per batch element. Otherwise, the free bits are still per layer, but
+     are assigned on average to the whole batch. In both cases, the batch
+     average is returned, so it's simply a matter of computing mean(clamp(KL))
+     or clamp(mean(KL)).
+
+     Parameters
+     ----------
+     kl : torch.Tensor
+         The KL divergence tensor with shape (batch size, layers).
+     free_bits : float
+         The free bits value. Set to 0.0 to disable free bits.
+     batch_average : bool
+         Whether to average over the batch before clamping to `free_bits`.
+     eps : float
+         A small value to avoid numerical instability.
+
+     Returns
+     -------
+     torch.Tensor
+         The free-bits version of the KL divergence with shape (layers,).
+     """
+     assert kl.dim() == 2
+     if free_bits < eps:
+         return kl.mean(0)
+     if batch_average:
+         return kl.mean(0).clamp(min=free_bits)
+     return kl.clamp(min=free_bits).mean(0)
+
+
+ def get_kl_weight(
+     kl_annealing: bool,
+     kl_start: int,
+     kl_annealtime: int,
+     kl_weight: Optional[float],
+     current_epoch: int,
+ ) -> float:
+     """Compute the weight of the KL loss, with optional annealing.
+
+     Parameters
+     ----------
+     kl_annealing : bool
+         Whether to use KL annealing.
+     kl_start : int
+         The epoch at which annealing starts.
+     kl_annealtime : int
+         The number of epochs over which annealing is applied.
+     kl_weight : Optional[float]
+         The final weight for the KL loss. During annealing it scales the
+         annealed factor; if `None`, it defaults to 1.
+     current_epoch : int
+         The current epoch.
+
+     Returns
+     -------
+     float
+         The weight of the KL loss.
+     """
+     if kl_annealing:
+         # calculate relative annealed weight and clamp to [0, 1]
+         annealed_weight = (current_epoch - kl_start) * (1.0 / kl_annealtime)
+         annealed_weight = min(max(0.0, annealed_weight), 1.0)
+
+         # if the final weight is given, apply it on top of the annealed factor
+         if kl_weight is not None:
+             return annealed_weight * kl_weight
+         return annealed_weight
+     elif kl_weight is not None:
+         return kl_weight
+     else:
+         return 1.0
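
To make the clamp-then-average versus average-then-clamp distinction in free_bits_kl concrete, a small numerical sketch (the KL values are arbitrary):

import torch

from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight

# KL for a batch of 4 samples across 3 latent layers
kl = torch.tensor(
    [
        [0.0, 0.5, 2.0],
        [0.1, 0.4, 1.5],
        [0.0, 0.6, 2.5],
        [0.2, 0.3, 1.0],
    ]
)

print(free_bits_kl(kl, free_bits=0.0))  # plain per-layer mean: [0.0750, 0.4500, 1.7500]
print(free_bits_kl(kl, free_bits=0.5))  # clamp each element, then mean: [0.5000, 0.5250, 1.7500]
print(free_bits_kl(kl, free_bits=0.5, batch_average=True))  # mean, then clamp: [0.5000, 0.5000, 1.7500]

# annealing ramps the KL weight linearly from epoch kl_start over kl_annealtime
# epochs, scaled by the final weight: (10 - 5) / 10 * 1.0
print(get_kl_weight(True, kl_start=5, kl_annealtime=10, kl_weight=1.0, current_epoch=10))  # 0.5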