careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +24 -7
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +55 -4
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +41 -4
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/optimizer_models.py +1 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/training_model.py +0 -2
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +229 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +451 -0
- careamics/dataset_ng/legacy_interoperability.py +170 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +678 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
- careamics/lightning/lightning_module.py +5 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/compose.py +1 -0
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/normalize.py +18 -7
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +25 -11
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""CARE Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, Union
|
|
5
|
+
|
|
6
|
+
from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
|
|
7
|
+
from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
|
|
8
|
+
from careamics.config.support import SupportedLoss
|
|
9
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
10
|
+
from careamics.losses import mae_loss, mse_loss
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
from .unet_module import UnetModule
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CAREModule(UnetModule):
    """CAREamics PyTorch Lightning module for the CARE and N2N algorithms.

    The UNet, metrics, logging and optimizer setup are inherited from
    `UnetModule`; this class only selects the supervised loss and implements the
    training and validation steps on (input, target) pairs.

    Parameters
    ----------
    algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
        Configuration for the algorithm, either as a `CAREAlgorithm` /
        `N2NAlgorithm` instance or a dictionary that can be converted to one.
    """

    def __init__(
        self, algorithm_config: Union[CAREAlgorithm, N2NAlgorithm, dict]
    ) -> None:
        """Instantiate the CARE Lightning module.

        Parameters
        ----------
        algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
            Configuration for the algorithm, either as a `CAREAlgorithm` /
            `N2NAlgorithm` instance or a dictionary that can be converted to one.

        Raises
        ------
        TypeError
            If the (converted) configuration is neither a `CAREAlgorithm` nor an
            `N2NAlgorithm`.
        ValueError
            If the configured loss is neither MAE nor MSE.
        """
        super().__init__(algorithm_config)

        # The parent converts a dict into an algorithm model and stores the
        # result in `self.config`. Validate that object rather than the raw
        # argument, otherwise dictionary inputs would always be rejected. A
        # tuple is used instead of `A | B` so that `isinstance` also works on
        # Python < 3.10.
        config = self.config
        if not isinstance(config, (CAREAlgorithm, N2NAlgorithm)):
            raise TypeError(
                "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
            )

        loss = config.loss
        if loss == SupportedLoss.MAE:
            self.loss_func: Callable = mae_loss
        elif loss == SupportedLoss.MSE:
            self.loss_func = mse_loss
        else:
            raise ValueError(f"Unsupported loss for Care: {loss}")

    def training_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> Any:
        """Training step for CARE module.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the training loop (unused).

        Returns
        -------
        Any
            The loss value computed for the current batch.
        """
        # TODO: add validation to determine if target is initialized
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        loss = self.loss_func(prediction, target.data)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> None:
        """Validation step for CARE module.

        Computes the validation loss and updates the metric collection, then
        logs both via `_log_validation_stats`.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the validation loop (unused).
        """
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        val_loss = self.loss_func(prediction, target.data)
        # Accumulate metric state; the collection is logged in
        # `_log_validation_stats`.
        self.metrics(prediction, target.data)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Noise2Void Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
from careamics.config import (
|
|
6
|
+
N2VAlgorithm,
|
|
7
|
+
)
|
|
8
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
9
|
+
from careamics.losses import n2v_loss
|
|
10
|
+
from careamics.transforms import N2VManipulateTorch
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
from .unet_module import UnetModule
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class N2VModule(UnetModule):
    """CAREamics PyTorch Lightning module for the N2V algorithm.

    The UNet, metrics, logging and optimizer setup are inherited from
    `UnetModule`; this class applies N2V pixel manipulation to each batch and
    trains with the N2V loss.

    Parameters
    ----------
    algorithm_config : N2VAlgorithm or dict
        Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
        dictionary that can be converted to one.
    """

    def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
        """Instantiate the N2V Lightning module.

        Parameters
        ----------
        algorithm_config : N2VAlgorithm or dict
            Configuration for the N2V algorithm, either as an N2VAlgorithm instance
            or a dictionary that can be converted to one.

        Raises
        ------
        TypeError
            If the (converted) configuration is not an `N2VAlgorithm`.
        """
        super().__init__(algorithm_config)

        # The parent converts a dict into an algorithm model and stores the
        # result in `self.config`. Read from there rather than from the raw
        # argument, otherwise dictionary inputs would always be rejected.
        config = self.config
        if not isinstance(config, N2VAlgorithm):
            raise TypeError("algorithm_config must be a N2VAlgorithm")

        self.n2v_manipulate = N2VManipulateTorch(
            n2v_manipulate_config=config.n2v_config
        )
        self.loss_func = n2v_loss

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint for N2V model, warning about its meaning."""
        logger.warning(
            "Loading best checkpoint for N2V model. Note that for N2V, "
            "the checkpoint with the best validation metrics may not necessarily "
            "have the best denoising performance."
        )
        super()._load_best_checkpoint()

    def _masked_forward(self, x: ImageRegionData) -> tuple[Any, Any, Any]:
        """Apply N2V manipulation to a batch and run the model on the result.

        Parameters
        ----------
        x : ImageRegionData
            Input batch; its `data` attribute is passed to `self.n2v_manipulate`.

        Returns
        -------
        tuple
            `(prediction, x_original, mask)`: the model output on the manipulated
            input, followed by the second and third values returned by
            `self.n2v_manipulate`.
        """
        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)
        return prediction, x_original, mask

    def training_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> Any:
        """Training step for N2V model.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data; any target is ignored.
        batch_idx : Any
            The index of the current batch in the training loop (unused).

        Returns
        -------
        Any
            The loss value for the current training step.
        """
        x = batch[0]
        prediction, x_original, mask = self._masked_forward(x)
        loss = self.loss_func(prediction, x_original, mask)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> None:
        """Validation step for N2V model.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data; any target is ignored.
        batch_idx : Any
            The index of the current batch in the validation loop (unused).
        """
        x = batch[0]
        prediction, x_original, mask = self._masked_forward(x)

        val_loss = self.loss_func(prediction, x_original, mask)
        # Accumulate metric state; the collection is logged in
        # `_log_validation_stats`.
        self.metrics(prediction, x_original)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Generic UNet Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
import pytorch_lightning as L
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torchmetrics import MetricCollection
|
|
9
|
+
from torchmetrics.image import PeakSignalNoiseRatio
|
|
10
|
+
|
|
11
|
+
from careamics.config import algorithm_factory
|
|
12
|
+
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
13
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
14
|
+
from careamics.models.unet import UNet
|
|
15
|
+
from careamics.transforms import Denormalize
|
|
16
|
+
from careamics.utils.logging import get_logger
|
|
17
|
+
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
18
|
+
|
|
19
|
+
logger = get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class UnetModule(L.LightningModule):
    """CAREamics PyTorch Lightning module for UNet based algorithms.

    Builds a `UNet` from the algorithm configuration and provides the shared
    logging, best-checkpoint loading, prediction and optimizer configuration used
    by the algorithm-specific subclasses.

    Parameters
    ----------
    algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
        Configuration for the algorithm, either as an instance of a specific algorithm
        class or a dictionary that can be converted to an algorithm instance.
    """

    def __init__(
        self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
    ) -> None:
        """Instantiate the UNet Lightning module.

        Parameters
        ----------
        algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
            Configuration for the algorithm, either as an instance of a specific
            algorithm class or a dictionary that can be converted to an algorithm
            instance.
        """
        super().__init__()

        if isinstance(algorithm_config, dict):
            # Convert the raw dictionary into an algorithm configuration model.
            algorithm_config = algorithm_factory(algorithm_config)

        # Subclasses should read the (possibly converted) configuration from here.
        self.config = algorithm_config
        self.model: nn.Module = UNet(**algorithm_config.model.model_dump())

        # Ensures `predict_step` attempts to load the best checkpoint only once.
        self._best_checkpoint_loaded = False

        # TODO: how to support metric evaluation better
        self.metrics = MetricCollection(PeakSignalNoiseRatio())

    def forward(self, x: Any) -> Any:
        """Default forward method.

        Parameters
        ----------
        x : Any
            Input data.

        Returns
        -------
        Any
            Output from the model.
        """
        return self.model(x)

    def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
        """Log training loss (per step and per epoch) and the learning rate.

        Parameters
        ----------
        loss : Any
            The loss value for the current training step.
        batch_size : Any
            The size of the batch used in the current training step.
        """
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
        )

        # `optimizers()` may return a list when several optimizers are
        # configured; report the learning rate of the first one.
        optimizer = self.optimizers()
        if isinstance(optimizer, list):
            current_lr = optimizer[0].param_groups[0]["lr"]
        else:
            current_lr = optimizer.param_groups[0]["lr"]
        self.log(
            "learning_rate",
            current_lr,
            on_step=False,
            on_epoch=True,
            logger=True,
            batch_size=batch_size,
        )

    def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
        """Log the validation loss and the accumulated metric collection.

        Parameters
        ----------
        loss : Any
            The loss value for the current validation step.
        batch_size : Any
            The size of the batch used in the current validation step.
        """
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
        )
        self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint from the trainer's checkpoint callback.

        Logs a warning and returns without changing the weights when there is no
        checkpoint callback or when it has not recorded a best model path.
        """
        checkpoint_callback = getattr(self.trainer, "checkpoint_callback", None)
        if checkpoint_callback is None:
            logger.warning("No checkpoint callback found, cannot load best checkpoint.")
            return

        best_model_path = checkpoint_callback.best_model_path
        if not best_model_path:
            logger.warning("No best checkpoint found.")
            return

        logger.info(f"Loading best checkpoint from: {best_model_path}")
        # Map tensors to CPU first: `load_state_dict` copies the values into the
        # existing parameters on their current device. This also avoids failures
        # when a checkpoint saved on a GPU is loaded on a machine without one.
        checkpoint = torch.load(best_model_path, map_location="cpu", weights_only=True)
        self.load_state_dict(checkpoint["state_dict"])

    def predict_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
        load_best_checkpoint=False,
    ) -> Any:
        """Default predict step.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple containing the input data and optionally the target data;
            only the input is used.
        batch_idx : Any
            The index of the current batch in the prediction loop.
        load_best_checkpoint : bool, default=False
            Whether to load the best checkpoint before making predictions. The
            load is attempted at most once per module instance.

        Returns
        -------
        Any
            An `ImageRegionData` whose `data` is the denormalized prediction as a
            NumPy array, and whose metadata is copied from the input batch.
        """
        if load_best_checkpoint and not self._best_checkpoint_loaded:
            self._load_best_checkpoint()
            self._best_checkpoint_loaded = True

        x = batch[0]
        # TODO: add TTA
        # `detach` so that `.numpy()` also works outside a no-grad context.
        prediction = self.model(x.data).detach().cpu().numpy()

        # Undo the input normalization using the datamodule's statistics.
        stats = self._trainer.datamodule.stats
        denormalize = Denormalize(
            image_means=stats.means,
            image_stds=stats.stds,
        )
        denormalized_output = denormalize(prediction)

        return ImageRegionData(
            data=denormalized_output,
            source=x.source,
            data_shape=x.data_shape,
            dtype=x.dtype,
            axes=x.axes,
            region_spec=x.region_spec,
        )

    def configure_optimizers(self) -> Any:
        """Configure optimizers.

        Returns
        -------
        Any
            A dictionary containing the optimizer, the learning rate scheduler and
            the metric monitored by the scheduler.
        """
        optimizer_func = get_optimizer(self.config.optimizer.name)
        optimizer = optimizer_func(
            self.model.parameters(), **self.config.optimizer.parameters
        )

        scheduler_func = get_scheduler(self.config.lr_scheduler.name)
        scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
        }
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""CAREamics Lightning module."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, Literal, Optional, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pytorch_lightning as L
|
|
@@ -148,6 +149,9 @@ class FCNModule(L.LightningModule):
|
|
|
148
149
|
self.log(
|
|
149
150
|
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
|
|
150
151
|
)
|
|
152
|
+
optimizer = self.optimizers()
|
|
153
|
+
current_lr = optimizer.param_groups[0]["lr"]
|
|
154
|
+
self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
|
|
151
155
|
return loss
|
|
152
156
|
|
|
153
157
|
def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytorch_lightning as L
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Training and validation Lightning data modules."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytorch_lightning as L
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -6,8 +6,9 @@ This module contains a factory function for creating loss functions.
|
|
|
6
6
|
|
|
7
7
|
from __future__ import annotations
|
|
8
8
|
|
|
9
|
+
from collections.abc import Callable
|
|
9
10
|
from dataclasses import dataclass
|
|
10
|
-
from typing import
|
|
11
|
+
from typing import Union
|
|
11
12
|
|
|
12
13
|
from torch import Tensor as tensor
|
|
13
14
|
|
|
@@ -1,14 +1,19 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .config import DatasetConfig
|
|
2
2
|
from .lc_dataset import LCMultiChDloader
|
|
3
|
+
from .ms_dataset_ref import MultiChDloaderRef
|
|
4
|
+
from .multich_dataset import MultiChDloader
|
|
5
|
+
from .multicrop_dset import MultiCropDset
|
|
3
6
|
from .multifile_dataset import MultiFileDset
|
|
4
|
-
from .
|
|
5
|
-
from .types import DataType, DataSplitType, TilingMode
|
|
7
|
+
from .types import DataSplitType, DataType, TilingMode
|
|
6
8
|
|
|
7
9
|
__all__ = [
|
|
8
10
|
"DatasetConfig",
|
|
9
11
|
"MultiChDloader",
|
|
10
12
|
"LCMultiChDloader",
|
|
11
13
|
"MultiFileDset",
|
|
14
|
+
"MultiCropDset",
|
|
15
|
+
"MultiChDloaderRef",
|
|
16
|
+
"LCMultiChDloaderRef",
|
|
12
17
|
"DataType",
|
|
13
18
|
"DataSplitType",
|
|
14
19
|
"TilingMode",
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, ConfigDict
|
|
4
4
|
|
|
@@ -43,7 +43,7 @@ class DatasetConfig(BaseModel):
|
|
|
43
43
|
image_size: tuple # TODO: revisit, new model_config uses tuple
|
|
44
44
|
"""Size of one patch of data"""
|
|
45
45
|
|
|
46
|
-
grid_size: Optional[int] = None
|
|
46
|
+
grid_size: Optional[Union[int, tuple[int, int, int]]] = None
|
|
47
47
|
"""Frame is divided into square grids of this size. A patch centered on a grid
|
|
48
48
|
having size `image_size` is returned. Grid size not used in training,
|
|
49
49
|
used only during val / test, grid size controls the overlap of the patches"""
|
|
@@ -82,7 +82,7 @@ class DatasetConfig(BaseModel):
|
|
|
82
82
|
# TODO: why is this not used?
|
|
83
83
|
enable_rotation_aug: Optional[bool] = False
|
|
84
84
|
|
|
85
|
-
max_val: Optional[float] = None
|
|
85
|
+
max_val: Optional[Union[float, tuple]] = None
|
|
86
86
|
"""Maximum data in the dataset. Is calculated for train split, and should be
|
|
87
87
|
externally set for val and test splits."""
|
|
88
88
|
|