careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset/dataset_utils/running_stats.py +7 -3
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/lightning_modules/unet_module.py (new file)

@@ -0,0 +1,143 @@
+from typing import Any, Union
+
+import pytorch_lightning as L
+import torch
+from torch import nn
+from torchmetrics import MetricCollection
+from torchmetrics.image import PeakSignalNoiseRatio
+
+from careamics.config import algorithm_factory
+from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
+from careamics.dataset_ng.dataset import ImageRegionData
+from careamics.models.unet import UNet
+from careamics.transforms import Denormalize
+from careamics.utils.logging import get_logger
+from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+logger = get_logger(__name__)
+
+
+class UnetModule(L.LightningModule):
+    """CAREamics PyTorch Lightning module for UNet based algorithms."""
+
+    def __init__(
+        self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
+    ) -> None:
+        super().__init__()
+
+        if isinstance(algorithm_config, dict):
+            algorithm_config = algorithm_factory(algorithm_config)
+
+        self.config = algorithm_config
+        self.model: nn.Module = UNet(**algorithm_config.model.model_dump())
+
+        self._best_checkpoint_loaded = False
+
+        # TODO: how to support metric evaluation better
+        self.metrics = MetricCollection(PeakSignalNoiseRatio())
+
+    def forward(self, x: Any) -> Any:
+        """Default forward method."""
+        return self.model(x)
+
+    def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
+        self.log(
+            "train_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            batch_size=batch_size,
+        )
+
+        optimizer = self.optimizers()
+        if isinstance(optimizer, list):
+            current_lr = optimizer[0].param_groups[0]["lr"]
+        else:
+            current_lr = optimizer.param_groups[0]["lr"]
+        self.log(
+            "learning_rate",
+            current_lr,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            batch_size=batch_size,
+        )
+
+    def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
+        self.log(
+            "val_loss",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            batch_size=batch_size,
+        )
+        self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
+
+    def _load_best_checkpoint(self) -> None:
+        if (
+            not hasattr(self.trainer, "checkpoint_callback")
+            or self.trainer.checkpoint_callback is None
+        ):
+            logger.warning("No checkpoint callback found, cannot load best checkpoint.")
+            return
+
+        best_model_path = self.trainer.checkpoint_callback.best_model_path
+        if best_model_path and best_model_path != "":
+            logger.info(f"Loading best checkpoint from: {best_model_path}")
+            model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
+            self.load_state_dict(model_state)
+        else:
+            logger.warning("No best checkpoint found.")
+
+    def predict_step(
+        self,
+        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+        batch_idx: Any,
+        load_best_checkpoint=False,
+    ) -> Any:
+        """Default predict step."""
+        if self._best_checkpoint_loaded is False and load_best_checkpoint:
+            self._load_best_checkpoint()
+            self._best_checkpoint_loaded = True
+
+        x = batch[0]
+        # TODO: add TTA
+        prediction = self.model(x.data).cpu().numpy()
+
+        means = self._trainer.datamodule.stats.means
+        stds = self._trainer.datamodule.stats.stds
+        denormalize = Denormalize(
+            image_means=means,
+            image_stds=stds,
+        )
+        denormalized_output = denormalize(prediction)
+
+        output_batch = ImageRegionData(
+            data=denormalized_output,
+            source=x.source,
+            data_shape=x.data_shape,
+            dtype=x.dtype,
+            axes=x.axes,
+            region_spec=x.region_spec,
+        )
+        return output_batch
+
+    def configure_optimizers(self) -> Any:
+        """Configure optimizers."""
+        optimizer_func = get_optimizer(self.config.optimizer.name)
+        optimizer = optimizer_func(
+            self.model.parameters(), **self.config.optimizer.parameters
+        )
+
+        scheduler_func = get_scheduler(self.config.lr_scheduler.name)
+        scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+        }
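For orientation, a minimal wiring sketch of how the new UnetModule could be used follows. It is not taken from the release: the re-export of UnetModule from the lightning_modules package is assumed, and the algorithm config and data module are placeholders for the real CAREamics objects.

# Illustrative sketch only; the import path and the Trainer settings are assumptions.
import pytorch_lightning as L

from careamics.lightning.dataset_ng.lightning_modules import UnetModule


def run_training(algorithm_config, datamodule):
    # UnetModule accepts a CAREAlgorithm / N2NAlgorithm / N2VAlgorithm config,
    # or an equivalent dict that it converts via algorithm_factory.
    module = UnetModule(algorithm_config=algorithm_config)
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(module, datamodule=datamodule)
    # predict_step denormalizes with the data module's stats and returns ImageRegionData
    return trainer.predict(module, datamodule=datamodule)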
careamics/lightning/lightning_module.py

@@ -148,6 +148,9 @@ class FCNModule(L.LightningModule):
         self.log(
             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
         )
+        optimizer = self.optimizers()
+        current_lr = optimizer.param_groups[0]["lr"]
+        self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
         return loss
 
     def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
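The hunk above adds per-epoch learning-rate logging to FCNModule's training step. As an illustration only (not part of the release), the new "learning_rate" metric can be read back after training if a CSVLogger is attached to the Trainer:

# Illustrative sketch, assuming a CSVLogger was used during training.
from pathlib import Path

import pandas as pd
from pytorch_lightning.loggers import CSVLogger

csv_logger = CSVLogger(save_dir="logs", name="fcn_run")
# trainer = L.Trainer(logger=csv_logger, ...); trainer.fit(module, datamodule=...)

metrics_path = Path(csv_logger.log_dir) / "metrics.csv"
if metrics_path.exists():
    metrics = pd.read_csv(metrics_path)
    print(metrics[["epoch", "learning_rate"]].dropna())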
careamics/lvae_training/dataset/__init__.py

@@ -1,14 +1,19 @@
-from .
+from .config import DatasetConfig
 from .lc_dataset import LCMultiChDloader
+from .ms_dataset_ref import MultiChDloaderRef
+from .multich_dataset import MultiChDloader
+from .multicrop_dset import MultiCropDset
 from .multifile_dataset import MultiFileDset
-from .
-from .types import DataType, DataSplitType, TilingMode
+from .types import DataSplitType, DataType, TilingMode
 
 __all__ = [
     "DatasetConfig",
     "MultiChDloader",
     "LCMultiChDloader",
     "MultiFileDset",
+    "MultiCropDset",
+    "MultiChDloaderRef",
+    "LCMultiChDloaderRef",
     "DataType",
     "DataSplitType",
     "TilingMode",
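With these exports in place, the new dataset classes can be imported directly from the package; a short illustrative import using only names added to __all__ above:

# Illustrative only: importing the classes newly exposed by the dataset package.
from careamics.lvae_training.dataset import (
    DatasetConfig,
    MultiChDloaderRef,
    MultiCropDset,
)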
careamics/lvae_training/dataset/config.py

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict
 

@@ -43,7 +43,7 @@ class DatasetConfig(BaseModel):
     image_size: tuple # TODO: revisit, new model_config uses tuple
     """Size of one patch of data"""
 
-    grid_size: Optional[int] = None
+    grid_size: Optional[Union[int, tuple[int, int, int]]] = None
     """Frame is divided into square grids of this size. A patch centered on a grid
     having size `image_size` is returned. Grid size not used in training,
     used only during val / test, grid size controls the overlap of the patches"""

@@ -82,7 +82,7 @@ class DatasetConfig(BaseModel):
     # TODO: why is this not used?
     enable_rotation_aug: Optional[bool] = False
 
-    max_val: Optional[float] = None
+    max_val: Optional[Union[float, tuple]] = None
     """Maximum data in the dataset. Is calculated for train split, and should be
     externally set for val and test splits."""
 
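The two DatasetConfig fields above now accept tuples as well as scalars. A standalone sketch of what the widened annotations validate, using a stand-in pydantic model rather than the real DatasetConfig (which has additional required fields):

# Illustrative stand-in, not the CAREamics class: mirrors only the widened fields.
from typing import Optional, Union

from pydantic import BaseModel


class GridFieldsDemo(BaseModel):
    grid_size: Optional[Union[int, tuple[int, int, int]]] = None
    max_val: Optional[Union[float, tuple]] = None


print(GridFieldsDemo(grid_size=32, max_val=4095.0))               # scalars, as before
print(GridFieldsDemo(grid_size=(1, 32, 32), max_val=(1.0, 2.0)))  # tuples now validate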