careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loss factory module.
|
|
3
|
+
|
|
4
|
+
This module contains a factory function for creating loss functions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Callable, Union
|
|
8
|
+
|
|
9
|
+
from ..config.support import SupportedLoss
|
|
10
|
+
from .losses import mae_loss, mse_loss, n2v_loss
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# TODO: add tests.
# TODO: consider supporting custom (user-provided) loss functions.
|
|
15
|
+
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Look up the loss function matching the requested loss.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss, either a `SupportedLoss` member or a value that
        compares equal to one.

    Returns
    -------
    Callable
        The matching loss function (`n2v_loss`, `mae_loss` or `mse_loss`).

    Raises
    ------
    NotImplementedError
        If `loss` does not match any supported loss.
    """
    # Candidates are tested in order with `==`. Plain strings are accepted
    # only if they compare equal to a member (assumes SupportedLoss is a
    # str-valued enum -- TODO confirm). PN2V and DICE are not wired up yet.
    registry = (
        (SupportedLoss.N2V, n2v_loss),
        (SupportedLoss.MAE, mae_loss),
        (SupportedLoss.MSE, mse_loss),
    )
    for supported_loss, loss_function in registry:
        if loss == supported_loss:
            return loss_function

    raise NotImplementedError(f"Loss {loss} is not yet supported.")
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loss submodule.
|
|
3
|
+
|
|
4
|
+
This submodule contains the various losses used in CAREamics.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.nn import L1Loss, MSELoss
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean squared error between two tensors.

    Parameters
    ----------
    source : torch.Tensor
        Source patches.
    target : torch.Tensor
        Target patches.

    Returns
    -------
    torch.Tensor
        Scalar loss value: squared differences averaged over all elements.
    """
    # `reduction="mean"` is the MSELoss default, spelled out for clarity.
    return MSELoss(reduction="mean")(source, target)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def n2v_loss(
    manipulated_patches: torch.Tensor,
    original_patches: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """
    Noise2Void loss function described in A. Krull et al. (2018).

    Squared error between `manipulated_patches` and `original_patches`,
    averaged over the pixel locations selected by `masks` only.

    Parameters
    ----------
    manipulated_patches : torch.Tensor
        Patches with manipulated pixels.
    original_patches : torch.Tensor
        Noisy patches.
    masks : torch.Tensor
        Array containing masked pixel locations (non-zero where masked). Must
        be broadcastable against the patches; boolean or numeric.

    Returns
    -------
    torch.Tensor
        Scalar loss value. If `masks` sums to zero, the result is 0 rather
        than the NaN produced by a 0 / 0 division.
    """
    errors = (original_patches - manipulated_patches) ** 2

    masked_error_sum = torch.sum(errors * masks)
    n_masked = torch.sum(masks)

    # An all-zero mask would give 0 / 0 = NaN, which would poison the
    # optimizer. Substituting a denominator of 1 yields 0 (the numerator is
    # then 0 for finite errors) and keeps gradients finite. Applying
    # `torch.where` to the quotient would not help: the NaN branch it
    # discards would still back-propagate NaNs.
    safe_n_masked = torch.where(
        n_masked != 0, n_masked, torch.ones_like(n_masked)
    )

    # Average over masked pixels and batch
    return masked_error_sum / safe_n_masked
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean absolute error (L1) loss, as used for Noise2Noise.

    Noise2Noise is described in J. Lehtinen et al. (2018).

    Parameters
    ----------
    samples : torch.Tensor
        Raw patches.
    labels : torch.Tensor
        Different subset of noisy patches.

    Returns
    -------
    torch.Tensor
        Scalar loss value: absolute differences averaged over all elements.
    """
    # `reduction="mean"` is the L1Loss default, spelled out for clarity.
    l1 = L1Loss(reduction="mean")
    value = l1(samples, labels)
    return value
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# def pn2v_loss(
|
|
80
|
+
# samples: torch.Tensor,
|
|
81
|
+
# labels: torch.Tensor,
|
|
82
|
+
# masks: torch.Tensor,
|
|
83
|
+
# noise_model: HistogramNoiseModel,
|
|
84
|
+
# ) -> torch.Tensor:
|
|
85
|
+
# """Probabilistic N2V loss function described in A Krull et al., CVF (2019)."""
|
|
86
|
+
# likelihoods = noise_model.likelihood(labels, samples)
|
|
87
|
+
# likelihoods_avg = torch.log(torch.mean(likelihoods, dim=0, keepdim=True)[0, ...])
|
|
88
|
+
|
|
89
|
+
# # Average over pixels and batch
|
|
90
|
+
# loss = -torch.sum(likelihoods_avg * masks) / torch.sum(masks)
|
|
91
|
+
# return loss
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# def dice_loss(
|
|
95
|
+
# samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
|
|
96
|
+
# ) -> torch.Tensor:
|
|
97
|
+
# """Dice loss function."""
|
|
98
|
+
# return DiceLoss(mode=mode)(samples, labels.long())
|
|
File without changes
|