careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/models/unet.py +16 -10
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,12 +16,15 @@ class SupportedData(str, BaseEnum):
|
|
|
16
16
|
Array data.
|
|
17
17
|
TIFF : str
|
|
18
18
|
TIFF image data.
|
|
19
|
+
CZI : str
|
|
20
|
+
CZI image data.
|
|
19
21
|
CUSTOM : str
|
|
20
22
|
Custom data.
|
|
21
23
|
"""
|
|
22
24
|
|
|
23
25
|
ARRAY = "array"
|
|
24
26
|
TIFF = "tiff"
|
|
27
|
+
CZI = "czi"
|
|
25
28
|
CUSTOM = "custom"
|
|
26
29
|
# ZARR = "zarr"
|
|
27
30
|
|
|
@@ -78,6 +81,8 @@ class SupportedData(str, BaseEnum):
|
|
|
78
81
|
raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
|
|
79
82
|
elif data_type == cls.TIFF:
|
|
80
83
|
return "*.tif*"
|
|
84
|
+
elif data_type == cls.CZI:
|
|
85
|
+
return "*.czi"
|
|
81
86
|
elif data_type == cls.CUSTOM:
|
|
82
87
|
return "*.*"
|
|
83
88
|
else:
|
|
@@ -102,6 +107,8 @@ class SupportedData(str, BaseEnum):
|
|
|
102
107
|
raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
|
|
103
108
|
elif data_type == cls.TIFF:
|
|
104
109
|
return ".tiff"
|
|
110
|
+
elif data_type == cls.CZI:
|
|
111
|
+
return ".czi"
|
|
105
112
|
elif data_type == cls.CUSTOM:
|
|
106
113
|
# TODO: improve this message
|
|
107
114
|
raise NotImplementedError("Custom extensions have to be passed elsewhere.")
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Patching strategies supported by Careamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedPatchingStrategy(str, BaseEnum):
|
|
7
|
+
"""Patching strategies supported by Careamics."""
|
|
8
|
+
|
|
9
|
+
FIXED_RANDOM = "fixed_random"
|
|
10
|
+
"""Fixed random patching strategy, used during training."""
|
|
11
|
+
|
|
12
|
+
RANDOM = "random"
|
|
13
|
+
"""Random patching strategy, used during training."""
|
|
14
|
+
|
|
15
|
+
# SEQUENTIAL = "sequential"
|
|
16
|
+
# """Sequential patching strategy, used during training."""
|
|
17
|
+
|
|
18
|
+
TILED = "tiled"
|
|
19
|
+
"""Tiled patching strategy, used during prediction."""
|
|
20
|
+
|
|
21
|
+
WHOLE = "whole"
|
|
22
|
+
"""Whole image patching strategy, used during prediction."""
|
|
@@ -4,7 +4,8 @@ Validator functions.
|
|
|
4
4
|
These functions are used to validate dimensions and axes of inputs.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from typing import Optional
|
|
8
9
|
|
|
9
10
|
_AXES = "STCZYX"
|
|
10
11
|
|
|
@@ -79,14 +80,14 @@ def value_ge_than_8_power_of_2(
|
|
|
79
80
|
|
|
80
81
|
|
|
81
82
|
def patch_size_ge_than_8_power_of_2(
|
|
82
|
-
patch_list: Optional[
|
|
83
|
+
patch_list: Optional[Sequence[int]],
|
|
83
84
|
) -> None:
|
|
84
85
|
"""
|
|
85
86
|
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
86
87
|
|
|
87
88
|
Parameters
|
|
88
89
|
----------
|
|
89
|
-
patch_list :
|
|
90
|
+
patch_list : Sequence of int, or None
|
|
90
91
|
Patch size.
|
|
91
92
|
|
|
92
93
|
Raises
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Callable, Generator
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import Optional, Union
|
|
8
8
|
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
10
|
from torch.utils.data import get_worker_info
|
|
@@ -3,8 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import Any,
|
|
8
|
+
from typing import Any, Optional, Union
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
from torch.utils.data import Dataset
|
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
|
-
from collections.abc import Generator
|
|
6
|
+
from collections.abc import Callable, Generator
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Optional
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from torch.utils.data import IterableDataset
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Callable, Generator
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
10
|
from torch.utils.data import IterableDataset
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Callable, Generator
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
8
8
|
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
10
|
from torch.utils.data import IterableDataset
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""Patching functions."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
6
|
+
from typing import Union
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
from numpy.typing import NDArray
|
|
@@ -89,7 +90,7 @@ def prepare_patches_supervised(
|
|
|
89
90
|
"""
|
|
90
91
|
means, stds, num_samples = 0, 0, 0
|
|
91
92
|
all_patches, all_targets = [], []
|
|
92
|
-
for train_filename, target_filename in zip(train_files, target_files):
|
|
93
|
+
for train_filename, target_filename in zip(train_files, target_files, strict=False):
|
|
93
94
|
try:
|
|
94
95
|
sample: np.ndarray = read_source_func(train_filename, axes)
|
|
95
96
|
target: np.ndarray = read_source_func(target_filename, axes)
|
|
@@ -78,7 +78,9 @@ def extract_tiles(
|
|
|
78
78
|
...,
|
|
79
79
|
*[
|
|
80
80
|
slice(coords, coords + extent)
|
|
81
|
-
for coords, extent in zip(
|
|
81
|
+
for coords, extent in zip(
|
|
82
|
+
crop_coords_start, tile_size, strict=False
|
|
83
|
+
)
|
|
82
84
|
],
|
|
83
85
|
)
|
|
84
86
|
tile = sample[crop_slices]
|
|
@@ -159,11 +161,14 @@ def compute_tile_info_legacy(
|
|
|
159
161
|
|
|
160
162
|
# --- combine start and end
|
|
161
163
|
stitch_coords = tuple(
|
|
162
|
-
(start, end)
|
|
164
|
+
(start, end)
|
|
165
|
+
for start, end in zip(stitch_coords_start, stitch_coords_end, strict=False)
|
|
163
166
|
)
|
|
164
167
|
overlap_crop_coords = tuple(
|
|
165
168
|
(start, end)
|
|
166
|
-
for start, end in zip(
|
|
169
|
+
for start, end in zip(
|
|
170
|
+
overlap_crop_coords_start, overlap_crop_coords_end, strict=False
|
|
171
|
+
)
|
|
167
172
|
)
|
|
168
173
|
|
|
169
174
|
tile_info = TileInformation(
|
|
@@ -229,11 +234,14 @@ def compute_tile_info(
|
|
|
229
234
|
|
|
230
235
|
# --- combine start and end
|
|
231
236
|
stitch_coords = tuple(
|
|
232
|
-
(start, end)
|
|
237
|
+
(start, end)
|
|
238
|
+
for start, end in zip(stitch_coords_start, stitch_coords_end, strict=False)
|
|
233
239
|
)
|
|
234
240
|
overlap_crop_coords = tuple(
|
|
235
241
|
(start, end)
|
|
236
|
-
for start, end in zip(
|
|
242
|
+
for start, end in zip(
|
|
243
|
+
overlap_crop_coords_start, overlap_crop_coords_end, strict=False
|
|
244
|
+
)
|
|
237
245
|
)
|
|
238
246
|
|
|
239
247
|
# --- Check if last tile
|
|
@@ -284,7 +292,9 @@ def compute_padding(
|
|
|
284
292
|
pad_before = overlaps // 2
|
|
285
293
|
pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before
|
|
286
294
|
|
|
287
|
-
return tuple(
|
|
295
|
+
return tuple(
|
|
296
|
+
(before, after) for before, after in zip(pad_before, pad_after, strict=False)
|
|
297
|
+
)
|
|
288
298
|
|
|
289
299
|
|
|
290
300
|
def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
|
|
@@ -127,7 +127,7 @@ def extract_tiles(
|
|
|
127
127
|
# Rearrange crop coordinates from a list of coordinate pairs per axis to a list
|
|
128
128
|
# grouped by type.
|
|
129
129
|
all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
|
|
130
|
-
*crop_and_stitch_coords_list
|
|
130
|
+
*crop_and_stitch_coords_list, strict=False
|
|
131
131
|
)
|
|
132
132
|
|
|
133
133
|
# Maximum tile index
|
|
@@ -139,6 +139,7 @@ def extract_tiles(
|
|
|
139
139
|
itertools.product(*all_crop_coords),
|
|
140
140
|
itertools.product(*all_stitch_coords),
|
|
141
141
|
itertools.product(*all_overlap_crop_coords),
|
|
142
|
+
strict=False,
|
|
142
143
|
)
|
|
143
144
|
):
|
|
144
145
|
# Extract tile from the sample
|
careamics/dataset_ng/dataset.py
CHANGED
|
@@ -8,7 +8,10 @@ from numpy.typing import NDArray
|
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
from tqdm.auto import tqdm
|
|
10
10
|
|
|
11
|
-
from careamics.config import
|
|
11
|
+
from careamics.config.data.ng_data_model import NGDataConfig
|
|
12
|
+
from careamics.config.support.supported_patching_strategies import (
|
|
13
|
+
SupportedPatchingStrategy,
|
|
14
|
+
)
|
|
12
15
|
from careamics.config.transformations import NormalizeModel
|
|
13
16
|
from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
|
|
14
17
|
from careamics.dataset.patching.patching import Stats
|
|
@@ -45,7 +48,7 @@ InputType = Union[Sequence[NDArray[Any]], Sequence[Path]]
|
|
|
45
48
|
class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
46
49
|
def __init__(
|
|
47
50
|
self,
|
|
48
|
-
data_config:
|
|
51
|
+
data_config: NGDataConfig,
|
|
49
52
|
mode: Mode,
|
|
50
53
|
input_extractor: PatchExtractor[GenericImageStack],
|
|
51
54
|
target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
|
|
@@ -65,33 +68,43 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
65
68
|
def _initialize_patching_strategy(self) -> PatchingStrategy:
|
|
66
69
|
patching_strategy: PatchingStrategy
|
|
67
70
|
if self.mode == Mode.TRAINING:
|
|
68
|
-
if
|
|
69
|
-
raise ValueError(
|
|
71
|
+
if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"Only `random` patching strategy supported during training, got "
|
|
74
|
+
f"{self.config.patching.name}."
|
|
75
|
+
)
|
|
76
|
+
|
|
70
77
|
patching_strategy = RandomPatchingStrategy(
|
|
71
78
|
data_shapes=self.input_extractor.shape,
|
|
72
|
-
patch_size=self.config.patch_size,
|
|
73
|
-
|
|
74
|
-
seed=getattr(self.config, "random_seed", 42),
|
|
79
|
+
patch_size=self.config.patching.patch_size,
|
|
80
|
+
seed=self.config.seed,
|
|
75
81
|
)
|
|
76
82
|
elif self.mode == Mode.VALIDATING:
|
|
77
|
-
if
|
|
78
|
-
raise ValueError(
|
|
83
|
+
if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
|
|
84
|
+
raise ValueError(
|
|
85
|
+
f"Only `random` patching strategy supported during training, got "
|
|
86
|
+
f"{self.config.patching.name}."
|
|
87
|
+
)
|
|
88
|
+
|
|
79
89
|
patching_strategy = FixedRandomPatchingStrategy(
|
|
80
90
|
data_shapes=self.input_extractor.shape,
|
|
81
|
-
patch_size=self.config.patch_size,
|
|
82
|
-
|
|
83
|
-
seed=getattr(self.config, "random_seed", 42),
|
|
91
|
+
patch_size=self.config.patching.patch_size,
|
|
92
|
+
seed=self.config.seed,
|
|
84
93
|
)
|
|
85
94
|
elif self.mode == Mode.PREDICTING:
|
|
86
|
-
if
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
self.config.tile_overlap is not None
|
|
95
|
+
if (
|
|
96
|
+
self.config.patching.name != SupportedPatchingStrategy.TILED
|
|
97
|
+
and self.config.patching.name != SupportedPatchingStrategy.WHOLE
|
|
90
98
|
):
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"Only `tiled` and `whole` patching strategy supported during "
|
|
101
|
+
f"training, got {self.config.patching.name}."
|
|
102
|
+
)
|
|
103
|
+
elif self.config.patching.name == SupportedPatchingStrategy.TILED:
|
|
91
104
|
patching_strategy = TilingStrategy(
|
|
92
105
|
data_shapes=self.input_extractor.shape,
|
|
93
|
-
tile_size=self.config.
|
|
94
|
-
overlaps=self.config.
|
|
106
|
+
tile_size=self.config.patching.patch_size,
|
|
107
|
+
overlaps=self.config.patching.overlaps,
|
|
95
108
|
)
|
|
96
109
|
else:
|
|
97
110
|
patching_strategy = WholeSamplePatchingStrategy(
|
|
@@ -103,32 +116,18 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
103
116
|
return patching_strategy
|
|
104
117
|
|
|
105
118
|
def _initialize_transforms(self) -> Optional[Compose]:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
target_stds=self.target_stats.stds,
|
|
116
|
-
)
|
|
117
|
-
]
|
|
118
|
-
+ list(self.config.transforms)
|
|
119
|
-
)
|
|
119
|
+
normalize = NormalizeModel(
|
|
120
|
+
image_means=self.input_stats.means,
|
|
121
|
+
image_stds=self.input_stats.stds,
|
|
122
|
+
target_means=self.target_stats.means,
|
|
123
|
+
target_stds=self.target_stats.stds,
|
|
124
|
+
)
|
|
125
|
+
if self.mode == Mode.TRAINING:
|
|
126
|
+
# TODO: initialize normalization separately depending on configuration
|
|
127
|
+
return Compose(transform_list=[normalize] + list(self.config.transforms))
|
|
120
128
|
|
|
121
129
|
# TODO: add TTA
|
|
122
|
-
return Compose(
|
|
123
|
-
transform_list=[
|
|
124
|
-
NormalizeModel(
|
|
125
|
-
image_means=self.input_stats.means,
|
|
126
|
-
image_stds=self.input_stats.stds,
|
|
127
|
-
target_means=self.target_stats.means,
|
|
128
|
-
target_stds=self.target_stats.stds,
|
|
129
|
-
)
|
|
130
|
-
]
|
|
131
|
-
)
|
|
130
|
+
return Compose(transform_list=[normalize])
|
|
132
131
|
|
|
133
132
|
def _calculate_stats(
|
|
134
133
|
self, data_extractor: PatchExtractor[GenericImageStack]
|
|
@@ -158,14 +157,11 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
158
157
|
input_stats = self._calculate_stats(self.input_extractor)
|
|
159
158
|
|
|
160
159
|
target_stats = Stats((), ())
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
)
|
|
166
|
-
target_stats = Stats(self.config.target_means, self.config.target_stds)
|
|
167
|
-
elif self.target_extractor is not None:
|
|
168
|
-
target_stats = self._calculate_stats(self.target_extractor)
|
|
160
|
+
|
|
161
|
+
if self.config.target_means is not None and self.config.target_stds is not None:
|
|
162
|
+
target_stats = Stats(self.config.target_means, self.config.target_stds)
|
|
163
|
+
elif self.target_extractor is not None:
|
|
164
|
+
target_stats = self._calculate_stats(self.target_extractor)
|
|
169
165
|
|
|
170
166
|
return input_stats, target_stats
|
|
171
167
|
|
|
@@ -13,8 +13,11 @@
|
|
|
13
13
|
"import tifffile\n",
|
|
14
14
|
"from careamics_portfolio import PortfolioManager\n",
|
|
15
15
|
"\n",
|
|
16
|
-
"from careamics.config.configuration_factories import
|
|
17
|
-
"
|
|
16
|
+
"from careamics.config.configuration_factories import (\n",
|
|
17
|
+
" _create_ng_data_configuration,\n",
|
|
18
|
+
" create_n2v_configuration,\n",
|
|
19
|
+
")\n",
|
|
20
|
+
"from careamics.config.data import NGDataConfig\n",
|
|
18
21
|
"from careamics.lightning.callbacks import HyperParametersCallback\n",
|
|
19
22
|
"from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
|
|
20
23
|
"from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
|
|
@@ -29,7 +32,8 @@
|
|
|
29
32
|
"# Set seeds for reproducibility\n",
|
|
30
33
|
"from pytorch_lightning import seed_everything\n",
|
|
31
34
|
"\n",
|
|
32
|
-
"
|
|
35
|
+
"seed = 42\n",
|
|
36
|
+
"seed_everything(seed)"
|
|
33
37
|
]
|
|
34
38
|
},
|
|
35
39
|
{
|
|
@@ -110,17 +114,17 @@
|
|
|
110
114
|
" num_epochs=100,\n",
|
|
111
115
|
")\n",
|
|
112
116
|
"\n",
|
|
113
|
-
"#
|
|
114
|
-
"
|
|
115
|
-
"
|
|
116
|
-
"
|
|
117
|
-
"
|
|
118
|
-
"
|
|
119
|
-
"
|
|
120
|
-
"
|
|
121
|
-
"
|
|
122
|
-
"
|
|
123
|
-
"
|
|
117
|
+
"# TODO until the NGDataConfig is accepted by the Configuration, these are separate\n",
|
|
118
|
+
"ng_data_config = _create_ng_data_configuration(\n",
|
|
119
|
+
" data_type=config.data_config.data_type,\n",
|
|
120
|
+
" axes=config.data_config.axes,\n",
|
|
121
|
+
" patch_size=config.data_config.patch_size,\n",
|
|
122
|
+
" batch_size=config.data_config.batch_size,\n",
|
|
123
|
+
" augmentations=config.data_config.transforms,\n",
|
|
124
|
+
" train_dataloader_params=config.data_config.train_dataloader_params,\n",
|
|
125
|
+
" val_dataloader_params=config.data_config.val_dataloader_params,\n",
|
|
126
|
+
" seed=seed,\n",
|
|
127
|
+
")\n"
|
|
124
128
|
]
|
|
125
129
|
},
|
|
126
130
|
{
|
|
@@ -137,7 +141,7 @@
|
|
|
137
141
|
"outputs": [],
|
|
138
142
|
"source": [
|
|
139
143
|
"train_data_module = CareamicsDataModule(\n",
|
|
140
|
-
" data_config=
|
|
144
|
+
" data_config=ng_data_config,\n",
|
|
141
145
|
" train_data=train_files,\n",
|
|
142
146
|
" val_data=val_files,\n",
|
|
143
147
|
")\n",
|
|
@@ -224,15 +228,16 @@
|
|
|
224
228
|
"metadata": {},
|
|
225
229
|
"outputs": [],
|
|
226
230
|
"source": [
|
|
227
|
-
"from careamics.config.inference_model import InferenceConfig\n",
|
|
228
231
|
"from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
|
|
229
232
|
"from careamics.prediction_utils import convert_outputs\n",
|
|
230
233
|
"\n",
|
|
231
|
-
"config =
|
|
232
|
-
" model_config=config,\n",
|
|
234
|
+
"config = NGDataConfig(\n",
|
|
233
235
|
" data_type=\"tiff\",\n",
|
|
234
|
-
"
|
|
235
|
-
"
|
|
236
|
+
" patching={\n",
|
|
237
|
+
" \"name\": \"tiled\",\n",
|
|
238
|
+
" \"patch_size\": (128, 128),\n",
|
|
239
|
+
" \"overlaps\": (32, 32),\n",
|
|
240
|
+
" },\n",
|
|
236
241
|
" axes=\"YX\",\n",
|
|
237
242
|
" batch_size=1,\n",
|
|
238
243
|
" image_means=train_data_module.train_dataset.input_stats.means,\n",
|
|
@@ -319,7 +324,7 @@
|
|
|
319
324
|
"psnrs = np.zeros((len(predictions), 1))\n",
|
|
320
325
|
"scale_invariant_psnrs = np.zeros((len(predictions), 1))\n",
|
|
321
326
|
"\n",
|
|
322
|
-
"for i, (pred, gt) in enumerate(zip(predictions, gts)):\n",
|
|
327
|
+
"for i, (pred, gt) in enumerate(zip(predictions, gts, strict=False)):\n",
|
|
323
328
|
" psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
|
|
324
329
|
" scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
|
|
325
330
|
"\n",
|
|
@@ -334,7 +339,7 @@
|
|
|
334
339
|
],
|
|
335
340
|
"metadata": {
|
|
336
341
|
"kernelspec": {
|
|
337
|
-
"display_name": "
|
|
342
|
+
"display_name": "czi",
|
|
338
343
|
"language": "python",
|
|
339
344
|
"name": "python3"
|
|
340
345
|
},
|
|
@@ -348,7 +353,7 @@
|
|
|
348
353
|
"name": "python",
|
|
349
354
|
"nbconvert_exporter": "python",
|
|
350
355
|
"pygments_lexer": "ipython3",
|
|
351
|
-
"version": "3.
|
|
356
|
+
"version": "3.12.11"
|
|
352
357
|
}
|
|
353
358
|
},
|
|
354
359
|
"nbformat": 4,
|
|
@@ -293,7 +293,7 @@
|
|
|
293
293
|
"psnrs = np.zeros((len(prediction), 1))\n",
|
|
294
294
|
"scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
|
|
295
295
|
"\n",
|
|
296
|
-
"for i, (pred, gt) in enumerate(zip(prediction, gts)):\n",
|
|
296
|
+
"for i, (pred, gt) in enumerate(zip(prediction, gts, strict=False)):\n",
|
|
297
297
|
" psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
|
|
298
298
|
" scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
|
|
299
299
|
"\n",
|
|
@@ -698,7 +698,7 @@
|
|
|
698
698
|
"psnrs = np.zeros((len(prediction), 1))\n",
|
|
699
699
|
"scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
|
|
700
700
|
"\n",
|
|
701
|
-
"for i, (pred, gt) in enumerate(zip(prediction, gts)):\n",
|
|
701
|
+
"for i, (pred, gt) in enumerate(zip(prediction, gts, strict=False)):\n",
|
|
702
702
|
" psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
|
|
703
703
|
" scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
|
|
704
704
|
"\n",
|