careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import pytorch_lightning as L
|
|
4
|
+
from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
|
|
5
|
+
from pytorch_lightning.loops.utilities import _no_grad_context
|
|
6
|
+
from pytorch_lightning.trainer import call
|
|
7
|
+
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
|
|
8
|
+
|
|
9
|
+
from careamics.prediction import stitch_prediction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CAREamicsPredictionLoop(L.loops._PredictionLoop):
    """
    CAREamics prediction loop.

    This class extends the PyTorch Lightning `_PredictionLoop` class to include
    the stitching of the tiles into a single prediction result.
    """

    def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """
        Calls `on_predict_epoch_end` hook.

        Adapted from the parent method. Instead of the raw per-batch
        predictions, it returns the (possibly stitched) predictions collected
        by `run`, each converted with `.numpy()`.

        Returns
        -------
        Optional[_PREDICT_OUTPUT]
            A single array if exactly one prediction was collected, otherwise a
            list of arrays; `None` if `return_predictions` is False.
        """
        trainer = self.trainer
        call._call_callback_hooks(trainer, "on_predict_epoch_end")
        call._call_lightning_module_hook(trainer, "on_predict_epoch_end")

        if self.return_predictions:
            # CAREamics specific: elements of `predicted_array` are assumed to
            # expose `.numpy()` (torch tensors) - TODO confirm for stitched output.
            if len(self.predicted_array) == 1:
                # TODO does this make sense to here? (force numpy array)
                return self.predicted_array[0].numpy()
            else:
                # TODO revisit logic
                return [element.numpy() for element in self.predicted_array]
        return None

    @_no_grad_context
    def run(self) -> Optional[_PREDICT_OUTPUT]:
        """
        Runs the prediction loop.

        Adapted from the parent method in order to stitch the predictions: after
        each predict step, the batch prediction is handed to
        `_collect_prediction`, which buffers tiles and stitches them once the
        last tile of an image has been seen.

        Returns
        -------
        Optional[_PREDICT_OUTPUT]
            Prediction output, as produced by `on_run_end`.
        """
        self.setup_data()
        if self.skip:
            return None
        self.reset()
        self.on_run_start()
        data_fetcher = self._data_fetcher
        assert data_fetcher is not None

        # CAREamics specific accumulators (reset on every run)
        self.predicted_array = []
        self.tiles = []
        self.stitching_data = []

        while True:
            try:
                if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
                    dataloader_iter = next(data_fetcher)
                    # hook's batch_idx and dataloader_idx arguments correctness cannot
                    # be guaranteed in this setting
                    batch = data_fetcher._batch
                    batch_idx = data_fetcher._batch_idx
                    dataloader_idx = data_fetcher._dataloader_idx
                else:
                    dataloader_iter = None
                    batch, batch_idx, dataloader_idx = next(data_fetcher)
                self.batch_progress.is_last_batch = data_fetcher.done

                # run step hooks
                self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)

                # CAREamics specific: the collected arrays are only ever returned
                # when `return_predictions` is set (see `_on_predict_epoch_end`),
                # so do not index `self.predictions` otherwise.
                # NOTE(review): indexing by batch_idx assumes a single predict
                # dataloader - confirm.
                if self.return_predictions:
                    self._collect_prediction(self.predictions[batch_idx])
            except StopIteration:
                break
            finally:
                self._restarting = False
        return self.on_run_end()

    def _collect_prediction(self, prediction) -> None:
        """
        Store one batch prediction, stitching tiles once an image is complete.

        A tiled prediction is a 2-element tuple/list ``(tile, tile_info)`` where
        ``tile_info`` unpacks as ``(last_tile, *stitch_data)``. Tiles and their
        stitching data are buffered until ``any(last_tile)`` is true, at which
        point they are stitched with `stitch_prediction` and the result is
        appended to ``self.predicted_array``. Any other prediction is appended
        unchanged.

        Parameters
        ----------
        prediction : Any
            Output of the predict step for one batch, as stored in
            ``self.predictions``.
        """
        # Require an actual tuple/list: a plain prediction tensor whose batch
        # dimension happens to be 2 also has len() == 2 and must not be
        # mistaken for a (tile, tile_info) pair.
        is_tiled = isinstance(prediction, (tuple, list)) and len(prediction) == 2
        if not is_tiled:
            # simply add the prediction to the list
            self.predicted_array.append(prediction)
            return

        # extract the last tile flag and the coordinates (crop and stitch)
        tile, tile_info = prediction
        last_tile, *stitch_data = tile_info

        # append the tile and the coordinates to the lists
        self.tiles.append(tile)
        self.stitching_data.append(stitch_data)

        # if last tile, stitch the tiles and add array to the prediction
        if any(last_tile):
            predicted_batches = stitch_prediction(self.tiles, self.stitching_data)
            self.predicted_array.append(predicted_batches)
            self.tiles.clear()
            self.stitching_data.clear()
|
|
115
|
+
|
|
116
|
+
# TODO predictions aren't stacked, list returned
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
from .loss_factory import
|
|
4
|
+
from .loss_factory import loss_factory
|
|
5
|
+
|
|
6
|
+
# from .noise_model_factory import noise_model_factory as noise_model_factory
|
|
7
|
+
# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -3,22 +3,21 @@ Loss factory module.
|
|
|
3
3
|
|
|
4
4
|
This module contains a factory function for creating loss functions.
|
|
5
5
|
"""
|
|
6
|
-
from typing import Callable
|
|
6
|
+
from typing import Callable, Union
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
8
|
+
from ..config.support import SupportedLoss
|
|
9
|
+
from .losses import mae_loss, mse_loss, n2v_loss
|
|
10
10
|
|
|
11
|
-
from .losses import n2v_loss
|
|
12
11
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
12
|
+
# TODO add tests
|
|
13
|
+
# TODO add custom?
|
|
14
|
+
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
15
|
+
"""Return loss function.
|
|
17
16
|
|
|
18
17
|
Parameters
|
|
19
18
|
----------
|
|
20
|
-
|
|
21
|
-
|
|
19
|
+
loss: SupportedLoss
|
|
20
|
+
Requested loss.
|
|
22
21
|
|
|
23
22
|
Returns
|
|
24
23
|
-------
|
|
@@ -30,9 +29,20 @@ def create_loss_function(config: Configuration) -> Callable:
|
|
|
30
29
|
NotImplementedError
|
|
31
30
|
If the loss is unknown.
|
|
32
31
|
"""
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
if loss_type == Loss.N2V:
|
|
32
|
+
if loss == SupportedLoss.N2V:
|
|
36
33
|
return n2v_loss
|
|
34
|
+
|
|
35
|
+
# elif loss_type == SupportedLoss.PN2V:
|
|
36
|
+
# return pn2v_loss
|
|
37
|
+
|
|
38
|
+
elif loss == SupportedLoss.MAE:
|
|
39
|
+
return mae_loss
|
|
40
|
+
|
|
41
|
+
elif loss == SupportedLoss.MSE:
|
|
42
|
+
return mse_loss
|
|
43
|
+
|
|
44
|
+
# elif loss_type == SupportedLoss.DICE:
|
|
45
|
+
# return dice_loss
|
|
46
|
+
|
|
37
47
|
else:
|
|
38
|
-
raise NotImplementedError(f"Loss {
|
|
48
|
+
raise NotImplementedError(f"Loss {loss} is not yet supported.")
|
careamics/losses/losses.py
CHANGED
|
@@ -3,14 +3,34 @@ Loss submodule.
|
|
|
3
3
|
|
|
4
4
|
This submodule contains the various losses used in CAREamics.
|
|
5
5
|
"""
|
|
6
|
+
|
|
6
7
|
import torch
|
|
7
8
|
|
|
9
|
+
# TODO if we are only using the DiceLoss, can we just implement it?
|
|
10
|
+
# from segmentation_models_pytorch.losses import DiceLoss
|
|
11
|
+
from torch.nn import L1Loss, MSELoss
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Mean squared error loss.

    Parameters
    ----------
    samples : torch.Tensor
        Input tensor (e.g. network prediction).
    labels : torch.Tensor
        Target tensor.

    Returns
    -------
    torch.Tensor
        Loss value, averaged over all elements.
    """
    # Functional form of `torch.nn.MSELoss()` with its default "mean" reduction.
    return torch.nn.functional.mse_loss(samples, labels)
|
|
25
|
+
|
|
8
26
|
|
|
9
27
|
def n2v_loss(
|
|
10
|
-
|
|
28
|
+
manipulated_patches: torch.Tensor,
|
|
29
|
+
original_patches: torch.Tensor,
|
|
30
|
+
masks: torch.Tensor,
|
|
11
31
|
) -> torch.Tensor:
|
|
12
32
|
"""
|
|
13
|
-
N2V Loss function
|
|
33
|
+
N2V Loss function described in A Krull et al 2018.
|
|
14
34
|
|
|
15
35
|
Parameters
|
|
16
36
|
----------
|
|
@@ -20,15 +40,55 @@ def n2v_loss(
|
|
|
20
40
|
Noisy patches.
|
|
21
41
|
masks : torch.Tensor
|
|
22
42
|
Array containing masked pixel locations.
|
|
23
|
-
device : str
|
|
24
|
-
Device to use.
|
|
25
43
|
|
|
26
44
|
Returns
|
|
27
45
|
-------
|
|
28
46
|
torch.Tensor
|
|
29
47
|
Loss value.
|
|
30
48
|
"""
|
|
31
|
-
errors = (
|
|
49
|
+
errors = (original_patches - manipulated_patches) ** 2
|
|
32
50
|
# Average over pixels and batch
|
|
33
51
|
loss = torch.sum(errors * masks) / torch.sum(masks)
|
|
34
52
|
return loss
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Mean absolute error (L1) loss, e.g. for Noise2Noise (J Lehtinen et al 2018).

    Parameters
    ----------
    samples : torch.Tensor
        Raw patches.
    labels : torch.Tensor
        Different subset of noisy patches.

    Returns
    -------
    torch.Tensor
        Loss value, averaged over all elements.
    """
    # Functional form of `torch.nn.L1Loss()` with its default "mean" reduction.
    return torch.nn.functional.l1_loss(samples, labels)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# def pn2v_loss(
|
|
76
|
+
# samples: torch.Tensor,
|
|
77
|
+
# labels: torch.Tensor,
|
|
78
|
+
# masks: torch.Tensor,
|
|
79
|
+
# noise_model: HistogramNoiseModel,
|
|
80
|
+
# ) -> torch.Tensor:
|
|
81
|
+
# """Probabilistic N2V loss function described in A Krull et al., CVF (2019)."""
|
|
82
|
+
# likelihoods = noise_model.likelihood(labels, samples)
|
|
83
|
+
# likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
|
|
84
|
+
|
|
85
|
+
# # Average over pixels and batch
|
|
86
|
+
# loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
|
|
87
|
+
# return loss
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# def dice_loss(
|
|
91
|
+
# samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
|
|
92
|
+
# ) -> torch.Tensor:
|
|
93
|
+
# """Dice loss function."""
|
|
94
|
+
# return DiceLoss(mode=mode)(samples, labels.long())
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Type, Union
|
|
2
|
+
|
|
3
|
+
from ..config.noise_models import NoiseModel, NoiseModelType
|
|
4
|
+
from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def noise_model_factory(
    noise_config: Union[NoiseModel, None],
) -> Union[Type[HistogramNoiseModel], Type[GaussianMixtureNoiseModel], None]:
    """Return the noise model class matching a noise model configuration.

    Parameters
    ----------
    noise_config : NoiseModel or None
        Noise model configuration. If it is falsy (e.g. `None`), no noise model
        is used and `None` is returned.

    Returns
    -------
    type or None
        The noise model class itself (`HistogramNoiseModel` or
        `GaussianMixtureNoiseModel`, not an instance), or `None` if no noise
        model is configured.

    Raises
    ------
    NotImplementedError
        If the configured noise model type is not supported.
    """
    noise_model_type = noise_config.model_type if noise_config else None

    if noise_model_type == NoiseModelType.HIST:
        return HistogramNoiseModel

    elif noise_model_type == NoiseModelType.GMM:
        return GaussianMixtureNoiseModel

    elif noise_model_type is None:
        return None

    else:
        raise NotImplementedError(
            f"Noise model {noise_model_type} is not yet supported."
        )
|