careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
- careamics/careamist.py +24 -7
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +55 -4
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +41 -4
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/optimizer_models.py +1 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/training_model.py +0 -2
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +229 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +451 -0
- careamics/dataset_ng/legacy_interoperability.py +170 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +678 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
- careamics/lightning/lightning_module.py +5 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/compose.py +1 -0
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/normalize.py +18 -7
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +25 -11
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/dataset.py
@@ -0,0 +1,229 @@

```python
from collections.abc import Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Generic, Literal, NamedTuple, Optional, Union

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from careamics.config.data.ng_data_model import NGDataConfig
from careamics.config.support.supported_patching_strategies import (
    SupportedPatchingStrategy,
)
from careamics.config.transformations import NormalizeModel
from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
from careamics.dataset.patching.patching import Stats
from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
from careamics.dataset_ng.patching_strategies import (
    FixedRandomPatchingStrategy,
    PatchingStrategy,
    PatchSpecs,
    RandomPatchingStrategy,
    TilingStrategy,
    WholeSamplePatchingStrategy,
)
from careamics.transforms import Compose


class Mode(str, Enum):
    TRAINING = "training"
    VALIDATING = "validating"
    PREDICTING = "predicting"


class ImageRegionData(NamedTuple):
    data: NDArray
    source: Union[str, Literal["array"]]
    data_shape: Sequence[int]
    dtype: str  # dtype should be str for collate
    axes: str
    region_spec: PatchSpecs


InputType = Union[Sequence[NDArray[Any]], Sequence[Path]]


class CareamicsDataset(Dataset, Generic[GenericImageStack]):
    def __init__(
        self,
        data_config: NGDataConfig,
        mode: Mode,
        input_extractor: PatchExtractor[GenericImageStack],
        target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
    ):
        self.config = data_config
        self.mode = mode

        self.input_extractor = input_extractor
        self.target_extractor = target_extractor

        self.patching_strategy = self._initialize_patching_strategy()

        self.input_stats, self.target_stats = self._initialize_statistics()

        self.transforms = self._initialize_transforms()
```
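Every field of `ImageRegionData` is an array or a plain string, so a batch of regions passes through PyTorch's default collate unchanged, which is why `dtype` is stored as a `str`. A minimal sketch of that behaviour, using a simplified stand-in tuple rather than the careamics class:

```python
from typing import NamedTuple

import numpy as np
from torch.utils.data import default_collate


# Simplified stand-in for ImageRegionData, for illustration only.
class Region(NamedTuple):
    data: np.ndarray
    source: str
    dtype: str


regions = [
    Region(np.zeros((1, 64, 64), dtype=np.float32), "img_0.tiff", "float32"),
    Region(np.ones((1, 64, 64), dtype=np.float32), "img_1.tiff", "float32"),
]

batch = default_collate(regions)
print(type(batch))       # the NamedTuple type is preserved
print(batch.data.shape)  # torch.Size([2, 1, 64, 64]) -- arrays stacked into a tensor
print(batch.source)      # ['img_0.tiff', 'img_1.tiff'] -- strings gathered into a list
```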
```python
    def _initialize_patching_strategy(self) -> PatchingStrategy:
        patching_strategy: PatchingStrategy
        if self.mode == Mode.TRAINING:
            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
                raise ValueError(
                    f"Only the `random` patching strategy is supported during "
                    f"training, got {self.config.patching.name}."
                )

            patching_strategy = RandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patching.patch_size,
                seed=self.config.seed,
            )
        elif self.mode == Mode.VALIDATING:
            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
                raise ValueError(
                    f"Only the `random` patching strategy is supported during "
                    f"validation, got {self.config.patching.name}."
                )

            patching_strategy = FixedRandomPatchingStrategy(
                data_shapes=self.input_extractor.shape,
                patch_size=self.config.patching.patch_size,
                seed=self.config.seed,
            )
        elif self.mode == Mode.PREDICTING:
            if (
                self.config.patching.name != SupportedPatchingStrategy.TILED
                and self.config.patching.name != SupportedPatchingStrategy.WHOLE
            ):
                raise ValueError(
                    f"Only the `tiled` and `whole` patching strategies are "
                    f"supported during prediction, got {self.config.patching.name}."
                )
            elif self.config.patching.name == SupportedPatchingStrategy.TILED:
                patching_strategy = TilingStrategy(
                    data_shapes=self.input_extractor.shape,
                    tile_size=self.config.patching.patch_size,
                    overlaps=self.config.patching.overlaps,
                )
            else:
                patching_strategy = WholeSamplePatchingStrategy(
                    data_shapes=self.input_extractor.shape
                )
        else:
            raise ValueError(f"Unrecognised dataset mode {self.mode}.")

        return patching_strategy
```
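The mode-to-strategy mapping above is: `random` patches for training, the same fixed random patches for every validation epoch, and `tiled` or `whole` regions for prediction. Judging only from the surface the dataset consumes (`n_patches` and `get_patch_spec`), a toy strategy could look like the sketch below; the spec keys are taken from the `PatchSpecs` usage above, everything else is an illustrative assumption rather than the real careamics protocol:

```python
from typing import Any


class GridPatchingStrategy:
    """Toy strategy: non-overlapping 2D patches over a single 2D image.

    Illustrative only -- it mimics the n_patches/get_patch_spec surface
    that CareamicsDataset relies on, not the real careamics protocol.
    """

    def __init__(self, data_shape: tuple[int, int], patch_size: tuple[int, int]):
        self.data_shape = data_shape
        self.patch_size = patch_size
        self._rows = data_shape[0] // patch_size[0]
        self._cols = data_shape[1] // patch_size[1]

    @property
    def n_patches(self) -> int:
        return self._rows * self._cols

    def get_patch_spec(self, index: int) -> dict[str, Any]:
        row, col = divmod(index, self._cols)
        return {
            "data_idx": 0,    # single image stack
            "sample_idx": 0,  # single sample within the stack
            "coords": (row * self.patch_size[0], col * self.patch_size[1]),
            "patch_size": self.patch_size,
        }


strategy = GridPatchingStrategy(data_shape=(256, 256), patch_size=(64, 64))
print(strategy.n_patches)         # 16
print(strategy.get_patch_spec(5))  # patch at coords (64, 64)
```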
```python
    def _initialize_transforms(self) -> Optional[Compose]:
        normalize = NormalizeModel(
            image_means=self.input_stats.means,
            image_stds=self.input_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )
        if self.mode == Mode.TRAINING:
            # TODO: initialize normalization separately depending on configuration
            return Compose(transform_list=[normalize] + list(self.config.transforms))

        # TODO: add TTA
        return Compose(transform_list=[normalize])

    def _calculate_stats(
        self, data_extractor: PatchExtractor[GenericImageStack]
    ) -> Stats:
        image_stats = WelfordStatistics()
        n_patches = self.patching_strategy.n_patches

        for idx in tqdm(range(n_patches), desc="Computing statistics"):
            patch_spec = self.patching_strategy.get_patch_spec(idx)
            patch = data_extractor.extract_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            # TODO: statistics accept SCYX format, while patch is CYX
            image_stats.update(patch[None, ...], sample_idx=idx)

        image_means, image_stds = image_stats.finalize()
        return Stats(image_means, image_stds)
```
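`WelfordStatistics` is an online estimator: it folds one patch at a time into running mean and variance, so the statistics pass never needs the full dataset in memory. A minimal scalar sketch of Welford's update rule (a generic illustration, not the careamics API):

```python
import numpy as np

# Welford's online algorithm: single pass, numerically stable.
count, mean, m2 = 0, 0.0, 0.0
for patch in (np.random.rand(64, 64) for _ in range(10)):
    for x in patch.ravel():
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the already-updated mean

std = np.sqrt(m2 / count)
print(f"mean={mean:.4f}, std={std:.4f}")  # ~0.5 and ~0.289 for uniform noise
```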
```python
    # TODO: add running stats
    def _initialize_statistics(self) -> tuple[Stats, Stats]:
        if self.config.image_means is not None and self.config.image_stds is not None:
            input_stats = Stats(self.config.image_means, self.config.image_stds)
        else:
            input_stats = self._calculate_stats(self.input_extractor)

        target_stats = Stats((), ())

        if self.config.target_means is not None and self.config.target_stds is not None:
            target_stats = Stats(self.config.target_means, self.config.target_stds)
        elif self.target_extractor is not None:
            target_stats = self._calculate_stats(self.target_extractor)

        return input_stats, target_stats

    def __len__(self):
        return self.patching_strategy.n_patches

    def _create_image_region(
        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
    ) -> ImageRegionData:
        data_idx = patch_spec["data_idx"]
        source = extractor.image_stacks[data_idx].source
        return ImageRegionData(
            data=patch,
            source=str(source),
            dtype=str(extractor.image_stacks[data_idx].data_dtype),
            data_shape=extractor.image_stacks[data_idx].data_shape,
            # TODO: should it be axes of the original image instead?
            axes=self.config.axes,
            region_spec=patch_spec,
        )

    def __getitem__(
        self, index: int
    ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
        patch_spec = self.patching_strategy.get_patch_spec(index)
        input_patch = self.input_extractor.extract_patch(
            data_idx=patch_spec["data_idx"],
            sample_idx=patch_spec["sample_idx"],
            coords=patch_spec["coords"],
            patch_size=patch_spec["patch_size"],
        )

        target_patch = (
            self.target_extractor.extract_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            if self.target_extractor is not None
            else None
        )

        if self.transforms is not None:
            if self.target_extractor is not None:
                input_patch, target_patch = self.transforms(input_patch, target_patch)
            else:
                # TODO: Compose no longer returns None for the target patch,
                # hence this if/else
                (input_patch,) = self.transforms(input_patch, target_patch)
                target_patch = None

        input_data = self._create_image_region(
            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
        )

        if target_patch is not None and self.target_extractor is not None:
            target_data = self._create_image_region(
                patch=target_patch,
                patch_spec=patch_spec,
                extractor=self.target_extractor,
            )
            return input_data, target_data
        else:
            return (input_data,)
```
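Downstream, `CareamicsDataset` behaves like any map-style PyTorch dataset, so consumption is plain `DataLoader` iteration. A sketch, assuming `dataset` is an already-constructed instance (the construction helpers live in `careamics/dataset_ng/factory.py`):

```python
from torch.utils.data import DataLoader

# `dataset` is assumed to be an already-constructed CareamicsDataset
# (see careamics/dataset_ng/factory.py for construction helpers).
loader = DataLoader(dataset, batch_size=8, shuffle=True)

for batch in loader:
    # Without a target extractor each batch is a 1-tuple of collated
    # ImageRegionData; with targets it is an (input, target) pair.
    (inputs,) = batch
    print(inputs.data.shape)  # e.g. torch.Size([8, C, Y, X])
    print(inputs.source[:2])  # per-patch provenance survives collation
    break
```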
careamics/dataset_ng/demos/bsd68_demo.ipynb
@@ -0,0 +1,361 @@

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager

from careamics.config.configuration_factories import (
    _create_ng_data_configuration,
    create_n2v_configuration,
)
from careamics.config.data import NGDataConfig
from careamics.lightning.callbacks import HyperParametersCallback
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
from careamics.lightning.dataset_ng.lightning_modules import N2VModule
```

```python
# Set seeds for reproducibility
from pytorch_lightning import seed_everything

seed = 42
seed_everything(seed)
```

### Load data and set paths to it

```python
# instantiate the data portfolio manager and download the data
root_path = Path("./data")

portfolio = PortfolioManager()
files = portfolio.denoising.N2V_BSD68.download(root_path)

# create paths for the data
data_path = Path(root_path / "denoising-N2V_BSD68.unzip/BSD68_reproducibility_data")
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"

# list train, val and test files
train_files = sorted(train_path.rglob("*.tiff"))
val_files = sorted(val_path.rglob("*.tiff"))
test_files = sorted(test_path.rglob("*.tiff"))
```

### Visualize a single train and val image

```python
# load a training and a validation image and show them side by side
single_train_image = tifffile.imread(train_files[0])[0]
single_val_image = tifffile.imread(val_files[0])[0]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(single_train_image, cmap="gray")
ax[0].set_title("Training Image")
ax[1].imshow(single_val_image, cmap="gray")
ax[1].set_title("Validation Image")
```

### Create config

```python
config = create_n2v_configuration(
    experiment_name="bsd68_n2v",
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=100,
)

# TODO: until the NGDataConfig is accepted by the Configuration, these are separate
ng_data_config = _create_ng_data_configuration(
    data_type=config.data_config.data_type,
    axes=config.data_config.axes,
    patch_size=config.data_config.patch_size,
    batch_size=config.data_config.batch_size,
    augmentations=config.data_config.transforms,
    train_dataloader_params=config.data_config.train_dataloader_params,
    val_dataloader_params=config.data_config.val_dataloader_params,
    seed=seed,
)
```
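Since the configuration is a pydantic model, it can be round-tripped to disk for reproducibility. A short sketch, assuming the `save_configuration`/`load_configuration` helpers keep their current home in `careamics.config`:

```python
from pathlib import Path

# Assumes these helpers live in careamics.config, as in current releases.
from careamics.config import load_configuration, save_configuration

# Round-trip the experiment configuration to YAML so a run can be reproduced.
config_path = Path("bsd68_n2v") / "config.yml"
config_path.parent.mkdir(parents=True, exist_ok=True)

save_configuration(config, config_path)
restored = load_configuration(config_path)
assert restored.experiment_name == "bsd68_n2v"
```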
### Create Lightning datamodule and model

```python
train_data_module = CareamicsDataModule(
    data_config=ng_data_config,
    train_data=train_files,
    val_data=val_files,
)

model = N2VModule(config.algorithm_config)
```

### Manually initialize the datamodule and visualize single train and val batches

```python
train_data_module.setup("fit")
train_data_module.setup("validate")

train_batch = next(iter(train_data_module.train_dataloader()))
val_batch = next(iter(train_data_module.val_dataloader()))

fig, ax = plt.subplots(1, 8, figsize=(10, 5))
ax[0].set_title("Training Batch")
for i in range(8):
    ax[i].imshow(train_batch[0].data[i][0].numpy(), cmap="gray")

fig, ax = plt.subplots(1, 8, figsize=(10, 5))
ax[0].set_title("Validation Batch")
for i in range(8):
    ax[i].imshow(val_batch[0].data[i][0].numpy(), cmap="gray")
```

### Train the model

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

root = Path("bsd68_n2v")
callbacks = [
    ModelCheckpoint(
        dirpath=root / "checkpoints",
        filename="bsd68_new_lightning_module",
        save_last=True,
        monitor="val_loss",
        mode="min",
    ),
    HyperParametersCallback(config),
]
logger = WandbLogger(project="bsd68-n2v", name="bsd68_new_lightning_module")

trainer = Trainer(
    max_epochs=50, default_root_dir=root, callbacks=callbacks, logger=logger
)
trainer.fit(model, datamodule=train_data_module)
```
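With `save_last=True` on the checkpoint callback, an interrupted run can be resumed through Lightning's standard `ckpt_path` argument; a short sketch:

```python
# Resume training from the last checkpoint written by ModelCheckpoint.
# `ckpt_path` is standard pytorch_lightning API; the "last.ckpt" file name
# follows from save_last=True above.
trainer = Trainer(
    max_epochs=50, default_root_dir=root, callbacks=callbacks, logger=logger
)
trainer.fit(
    model,
    datamodule=train_data_module,
    ckpt_path=root / "checkpoints" / "last.ckpt",
)
```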
### Create an inference config and datamodule

```python
from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos
from careamics.prediction_utils import convert_outputs

config = NGDataConfig(
    data_type="tiff",
    patching={
        "name": "tiled",
        "patch_size": (128, 128),
        "overlaps": (32, 32),
    },
    axes="YX",
    batch_size=1,
    image_means=train_data_module.train_dataset.input_stats.means,
    image_stds=train_data_module.train_dataset.input_stats.stds,
)

inf_data_module = CareamicsDataModule(data_config=config, pred_data=test_files)
```
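For images small enough to process in one shot, tiling can be skipped in favour of the `whole` strategy listed in `supported_patching_strategies.py`. A hedged variant, assuming that strategy needs no `patch_size` or `overlaps` (check the `NGDataConfig` schema before relying on this):

```python
# Hypothetical variant: predict on whole images instead of tiles.
# Assumes the `whole` patching strategy takes no patch_size/overlaps;
# verify against NGDataConfig and the whole_patching_model schema.
whole_config = NGDataConfig(
    data_type="tiff",
    patching={"name": "whole"},
    axes="YX",
    batch_size=1,
    image_means=train_data_module.train_dataset.input_stats.means,
    image_stds=train_data_module.train_dataset.input_stats.stds,
)

whole_data_module = CareamicsDataModule(data_config=whole_config, pred_data=test_files)
```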
### Convert outputs to the legacy format and stitch the tiles

```python
predictions = trainer.predict(model, datamodule=inf_data_module)
tile_infos = imageregions_to_tileinfos(predictions)
predictions = convert_outputs(tile_infos, tiled=True)
```

### Visualize predictions and compute metrics

```python
from careamics.utils.metrics import psnr, scale_invariant_psnr

noises = [tifffile.imread(f) for f in sorted(test_path.glob("*.tiff"))]
gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))]

images = [0, 1, 2]
fig, ax = plt.subplots(3, 3, figsize=(15, 15))
fig.tight_layout()

for i in range(3):
    pred_image = predictions[images[i]].squeeze()
    psnr_noisy = psnr(
        gts[images[i]],
        noises[images[i]],
        data_range=gts[images[i]].max() - gts[images[i]].min(),
    )
    psnr_result = psnr(
        gts[images[i]],
        pred_image,
        data_range=gts[images[i]].max() - gts[images[i]].min(),
    )

    scale_invariant_psnr_result = scale_invariant_psnr(gts[images[i]], pred_image)

    ax[i, 0].imshow(noises[images[i]], cmap="gray")
    ax[i, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[i, 1].imshow(pred_image, cmap="gray")
    ax[i, 1].title.set_text(
        f"Prediction\nPSNR: {psnr_result:.2f}\n"
        f"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}"
    )

    ax[i, 2].imshow(gts[images[i]], cmap="gray")
    ax[i, 2].title.set_text("Ground-truth")
```

```python
psnrs = np.zeros((len(predictions), 1))
scale_invariant_psnrs = np.zeros((len(predictions), 1))

for i, (pred, gt) in enumerate(zip(predictions, gts, strict=False)):
    psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())
    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print(
    f"Scale invariant PSNR: "
    f"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}"
)
print("Reported PSNR: 27.71")
```