careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics has been flagged as potentially problematic.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
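One user-visible consequence of the file moves above: the careamics.prediction package is removed and the prediction helpers now live under careamics.prediction_utils. A hypothetical migration sketch for downstream code, assuming stitch_prediction is re-exported from the new package (its new __init__.py adds 12 lines, not shown in this excerpt):

# rc5 (removed module):
# from careamics.prediction import stitch_prediction

# rc7 (assumed re-export; otherwise import from
# careamics.prediction_utils.stitch_prediction directly):
from careamics.prediction_utils import stitch_prediction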
careamics/dataset/in_memory_tiled_pred_dataset.py (new file)
@@ -0,0 +1,129 @@
+"""In-memory tiled prediction dataset."""
+
+from __future__ import annotations
+
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
+from .dataset_utils import reshape_array
+from .tiling import extract_tiles
+
+
+class InMemoryTiledPredDataset(Dataset):
+    """Prediction dataset storing data in memory and returning tiles of each image.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : NDArray
+        Input data.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        inputs: NDArray,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : NDArray
+            Input data.
+
+        Raises
+        ------
+        ValueError
+            If the tile size or tile overlap is not provided.
+        """
+        if (
+            prediction_config.tile_size is None
+            or prediction_config.tile_overlap is None
+        ):
+            raise ValueError(
+                "Tile size and overlap must be provided to use the tiled prediction "
+                "dataset."
+            )
+
+        self.pred_config = prediction_config
+        self.input_array = inputs
+        self.axes = self.pred_config.axes
+        self.tile_size = prediction_config.tile_size
+        self.tile_overlap = prediction_config.tile_overlap
+        self.image_means = self.pred_config.image_means
+        self.image_stds = self.pred_config.image_stds
+
+        # generate tiles
+        self.data = self._prepare_tiles()
+
+        # get transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
+        )
+
+    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
+        """
+        Iterate over the data source and create a list of tiles.
+
+        Returns
+        -------
+        list of tuples of NDArray and TileInformation
+            List of tiles and tile information.
+        """
+        # reshape array
+        reshaped_sample = reshape_array(self.input_array, self.axes)
+
+        # generate tiles, which returns a generator
+        patch_generator = extract_tiles(
+            arr=reshaped_sample,
+            tile_size=self.tile_size,
+            overlaps=self.tile_overlap,
+        )
+        patches_list = list(patch_generator)
+
+        if len(patches_list) == 0:
+            raise ValueError("No tiles were generated.")
+
+        return patches_list
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            Length of the dataset.
+        """
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
+        """
+        Return the tile corresponding to the provided index.
+
+        Parameters
+        ----------
+        index : int
+            Index of the tile to return.
+
+        Returns
+        -------
+        tuple of NDArray and TileInformation
+            Transformed tile.
+        """
+        tile_array, tile_info = self.data[index]
+
+        # apply transforms
+        transformed_tile, _ = self.patch_transform(patch=tile_array)
+
+        return transformed_tile, tile_info
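For orientation, a minimal usage sketch of the new dataset. Only the InferenceConfig fields read in the hunk above (axes, tile_size, tile_overlap, image_means, image_stds) are grounded in this diff; the remaining constructor arguments and the import path are assumptions and may differ.

import numpy as np

from careamics.config import InferenceConfig
from careamics.dataset import InMemoryTiledPredDataset  # assumed import path

# hypothetical configuration values
config = InferenceConfig(
    data_type="array",  # assumption, not shown in this diff
    axes="YX",
    tile_size=[128, 128],
    tile_overlap=[32, 32],
    image_means=[128.0],
    image_stds=[32.0],
)

image = np.random.rand(512, 512).astype(np.float32)
dataset = InMemoryTiledPredDataset(prediction_config=config, inputs=image)

# each item is a (normalized tile, TileInformation) pair; the new
# careamics/dataset/tiling/collate_tiles.py presumably batches such pairs
tile, tile_info = dataset[0]
print(len(dataset), tile.shape)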
careamics/dataset/iterable_dataset.py
@@ -1,20 +1,27 @@
+"""Iterable dataset used to load data file by file."""
+
 from __future__ import annotations
 
 import copy
+from collections.abc import Generator
 from pathlib import Path
-from typing import …
+from typing import Callable, Optional
 
 import numpy as np
-from torch.utils.data import IterableDataset…
+from torch.utils.data import IterableDataset
 
+from careamics.config import DataConfig
+from careamics.config.transformations import NormalizeModel
 from careamics.transforms import Compose
 
-from ..config import DataConfig, InferenceConfig
-from ..config.tile_information import TileInformation
 from ..utils.logging import get_logger
-from .dataset_utils import …
+from .dataset_utils import (
+    iterate_over_files,
+    read_tiff,
+)
+from .dataset_utils.running_stats import WelfordStatistics
+from .patching.patching import Stats, StatsOutput
 from .patching.random_patching import extract_patches_random
-from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
@@ -25,129 +32,142 @@ class PathIterableDataset(IterableDataset):
 
     Parameters
     ----------
-    …
+    data_config : DataConfig
+        Data configuration.
+    src_files : list of pathlib.Path
+        List of data files.
+    target_files : list of pathlib.Path, optional
+        Optional list of target files, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+
+    Attributes
+    ----------
+    data_path : list of pathlib.Path
         Path to the data, must be a directory.
     axes : str
         Description of axes in format STCZYX.
-    patch_extraction_method : Union[ExtractionStrategies, None]
-        Patch extraction strategy, as defined in extraction_strategy.
-    patch_size : Optional[Union[List[int], Tuple[int]]], optional
-        Size of the patches in each dimension, by default None.
-    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-        Overlap of the patches in each dimension, by default None.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
     """
 
     def __init__(
         self,
-        data_config: …
-        src_files: …
-        target_files: Optional[…
+        data_config: DataConfig,
+        src_files: list[Path],
+        target_files: Optional[list[Path]] = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        src_files : list[Path]
+            List of data files.
+        target_files : list[Path] or None, optional
+            Optional list of target files, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        """
         self.data_config = data_config
         self.data_files = src_files
         self.target_files = target_files
-        self.data_config = data_config
         self.read_source_func = read_source_func
 
         # compute mean and std over the dataset
-        …
+        # only checking image_means because the DataConfig class ensures that
+        # if image_means is provided, image_stds is also provided
+        if not self.data_config.image_means:
+            self.data_stats = self._calculate_mean_and_std()
+            logger.info(
+                f"Computed dataset mean: {self.data_stats.image_stats.means}, "
+                f"std: {self.data_stats.image_stats.stds}"
+            )
+
+            # update the mean and std in the config
+            self.data_config.set_mean_and_std(
+                image_means=self.data_stats.image_stats.means,
+                image_stds=self.data_stats.image_stats.stds,
+                target_means=(
+                    list(self.data_stats.target_stats.means)
+                    if self.data_stats.target_stats.means is not None
+                    else None
+                ),
+                target_stds=(
+                    list(self.data_stats.target_stats.stds)
+                    if self.data_stats.target_stats.stds is not None
+                    else None
+                ),
+            )
 
-        # update mean and std in configuration
-        # the object is mutable and should then be recorded in the CAREamist
-        data_config.set_mean_and_std(self.mean, self.std)
         else:
-            self.…
+            # if mean and std are provided in the config, use them
+            self.data_stats = StatsOutput(
+                Stats(self.data_config.image_means, self.data_config.image_stds),
+                Stats(self.data_config.target_means, self.data_config.target_stds),
+            )
 
-        # …
-        self.patch_transform = Compose(…
+        # create transform composed of normalization and other transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.data_stats.image_stats.means,
+                    image_stds=self.data_stats.image_stats.stds,
+                    target_means=self.data_stats.target_stats.means,
+                    target_stds=self.data_stats.target_stats.stds,
+                )
+            ]
+            + data_config.transforms
+        )
 
-    def _calculate_mean_and_std(self) -> …
+    def _calculate_mean_and_std(self) -> StatsOutput:
         """
         Calculate mean and std of the dataset.
 
         Returns
         -------
-        …
+        StatsOutput
+            Data class containing the image statistics.
         """
-        means, stds = 0, 0
         num_samples = 0
+        image_stats = WelfordStatistics()
+        if self.target_files is not None:
+            target_stats = WelfordStatistics()
+
+        for sample, target in iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
+            # update the image statistics
+            image_stats.update(sample, num_samples)
+
+            # update the target statistics if a target is available
+            if target is not None:
+                target_stats.update(target, num_samples)
 
-        for sample, _ in self._iterate_over_files():
-            means += sample.mean()
-            stds += sample.std()
             num_samples += 1
 
         if num_samples == 0:
             raise ValueError("No samples found in the dataset.")
 
-        …
-        logger.info(f"Calculated mean and std for {num_samples} images")
-        logger.info(f"Mean: {result_mean}, std: {result_std}")
-        return result_mean, result_std
+        # average the means and stds per sample
+        image_means, image_stds = image_stats.finalize()
 
-        …
-        ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
-        """
-        Iterate over data source and yield whole image.
+        if target is not None:
+            target_means, target_stds = target_stats.finalize()
 
-        …
-        worker_info = get_worker_info()
-        worker_id = worker_info.id if worker_info is not None else 0
-        num_workers = worker_info.num_workers if worker_info is not None else 1
-
-        # iterate over the files
-        for i, filename in enumerate(self.data_files):
-            # retrieve file corresponding to the worker id
-            if i % num_workers == worker_id:
-                try:
-                    # read data
-                    sample = self.read_source_func(filename, self.data_config.axes)
-
-                    # read target, if available
-                    if self.target_files is not None:
-                        if filename.name != self.target_files[i].name:
-                            raise ValueError(
-                                f"File {filename} does not match target file "
-                                f"{self.target_files[i]}. Have you passed sorted "
-                                f"arrays?"
-                            )
-
-                        # read target
-                        target = self.read_source_func(
-                            self.target_files[i], self.data_config.axes
-                        )
-
-                        yield sample, target
-                    else:
-                        yield sample, None
-
-                except Exception as e:
-                    logger.error(f"Error reading file {filename}: {e}")
+        logger.info(f"Calculated mean and std for {num_samples} images")
+        logger.info(f"Mean: {image_means}, std: {image_stds}")
+        return StatsOutput(
+            Stats(image_means, image_stds),
+            Stats(
+                np.array(target_means) if target is not None else None,
+                np.array(target_stds) if target is not None else None,
+            ),
+        )
 
     def __iter__(
         self,
-    ) -> Generator[…
+    ) -> Generator[tuple[np.ndarray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
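The rc5 version accumulated sample.mean() and sample.std() per file and averaged them, which only approximates the global statistics; rc7 replaces this with a WelfordStatistics accumulator from the new careamics/dataset/dataset_utils/running_stats.py (+186 lines, not shown in this excerpt). As a reference for the technique, a minimal, self-contained sketch of Welford's single-pass algorithm (illustrative only, not the CAREamics class):

import numpy as np


class RunningStats:
    """Single-pass Welford accumulator (illustrative, not the CAREamics class)."""

    def __init__(self) -> None:
        self.count = 0   # number of values seen so far
        self.mean = 0.0  # running mean
        self.m2 = 0.0    # running sum of squared deviations from the mean

    def update(self, sample: np.ndarray) -> None:
        # fold every value of the sample into the running statistics
        for value in np.ravel(sample).astype(float):
            self.count += 1
            delta = value - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (value - self.mean)

    def finalize(self) -> tuple[float, float]:
        # population mean and standard deviation over all values seen
        return self.mean, float(np.sqrt(self.m2 / self.count))

Because the accumulator keeps only three scalars, the statistics can be computed file by file without holding the dataset in memory, which is exactly what the file-based iteration in this class needs.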
@@ -157,22 +177,18 @@ class PathIterableDataset(IterableDataset):
             Single patch.
         """
         assert (
-            self.…
+            self.data_stats.image_stats.means is not None
+            and self.data_stats.image_stats.stds is not None
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in …
-            …
-                None
-                if sample_target is None
-                else reshape_array(sample_target, self.data_config.axes)
-            )
-
+        for sample_input, sample_target in iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             patches = extract_patches_random(
-                arr=…
+                arr=sample_input,
                 patch_size=self.data_config.patch_size,
-                target=…
+                target=sample_target,
             )
 
             # iterate over patches
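extract_patches_random is only called here; its body lives in careamics/dataset/patching/random_patching.py and is not part of this excerpt. A generic sketch of the technique over a (S, C, Y, X) array, under those shape assumptions (illustrative, not the CAREamics implementation):

import numpy as np
from numpy.typing import NDArray


def random_patches_sketch(
    arr: NDArray, patch_size: tuple[int, int], n_patches: int, seed: int = 0
):
    """Yield random (C, py, px) patches from a (S, C, Y, X) array. Illustrative."""
    rng = np.random.default_rng(seed)
    py, px = patch_size
    for _ in range(n_patches):
        s = int(rng.integers(arr.shape[0]))           # random sample index
        y = int(rng.integers(arr.shape[2] - py + 1))  # random top edge
        x = int(rng.integers(arr.shape[3] - px + 1))  # random left edge
        yield arr[s, :, y : y + py, x : x + px]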
@@ -209,9 +225,9 @@ class PathIterableDataset(IterableDataset):
         Parameters
         ----------
         percentage : float, optional
-            Percentage of files to split up, by default 0.1
+            Percentage of files to split up, by default 0.1.
         minimum_number : int, optional
-            Minimum number of files to split up, by default 5
+            Minimum number of files to split up, by default 5.
 
         Returns
         -------
@@ -273,89 +289,3 @@ class PathIterableDataset(IterableDataset):
         dataset.target_files = val_target_files
 
         return dataset
-
-
-class IterablePredictionDataset(PathIterableDataset):
-    """
-    Dataset allowing extracting patches w/o loading whole data into memory.
-
-    Parameters
-    ----------
-    data_path : Union[str, Path]
-        Path to the data, must be a directory.
-    axes : str
-        Description of axes in format STCZYX.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        src_files: List[Path],
-        read_source_func: Callable = read_tiff,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(
-            data_config=prediction_config,
-            src_files=src_files,
-            read_source_func=read_source_func,
-        )
-
-        self.prediction_config = prediction_config
-        self.axes = prediction_config.axes
-        self.tile_size = self.prediction_config.tile_size
-        self.tile_overlap = self.prediction_config.tile_overlap
-        self.read_source_func = read_source_func
-
-        # tile only if both tile size and overlaps are provided
-        self.tile = self.tile_size is not None and self.tile_overlap is not None
-
-        # get tta transforms
-        self.patch_transform = Compose(
-            transform_list=prediction_config.transforms,
-        )
-
-    def __iter__(
-        self,
-    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-        """
-        Iterate over data source and yield single patch.
-
-        Yields
-        ------
-        np.ndarray
-            Single patch.
-        """
-        assert (
-            self.mean is not None and self.std is not None
-        ), "Mean and std must be provided"
-
-        for sample, _ in self._iterate_over_files():
-            # reshape array
-            reshaped_sample = reshape_array(sample, self.axes)
-
-            if self.tile:
-                # generate patches, return a generator
-                patch_gen = extract_tiles(
-                    arr=reshaped_sample,
-                    tile_size=self.tile_size,
-                    overlaps=self.tile_overlap,
-                )
-            else:
-                # just wrap the sample in a generator with default tiling info
-                array_shape = reshaped_sample.squeeze().shape
-                patch_gen = (
-                    (reshaped_sample, TileInformation(array_shape=array_shape))
-                    for _ in range(1)
-                )
-
-            # apply transform to patches
-            for patch_array, tile_info in patch_gen:
-                transformed_patch, _ = self.patch_transform(patch=patch_array)
-
-                yield transformed_patch, tile_info
careamics/dataset/iterable_pred_dataset.py (new file)
@@ -0,0 +1,121 @@
+"""Iterable prediction dataset used to load data file by file."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Callable, Generator
+
+from numpy.typing import NDArray
+from torch.utils.data import IterableDataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.transformations import NormalizeModel
+from .dataset_utils import iterate_over_files, read_tiff
+
+
+class IterablePredDataset(IterableDataset):
+    """Simple iterable prediction dataset.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : list of pathlib.Path
+        List of data files.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+
+    Attributes
+    ----------
+    data_files : list of pathlib.Path
+        List of data files.
+    axes : str
+        Description of axes in format STCZYX.
+    image_means : list of float
+        Expected means of the dataset.
+    image_stds : list of float
+        Expected standard deviations of the dataset.
+    patch_transform : Compose
+        Normalization transform applied to each sample.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        src_files: list[Path],
+        read_source_func: Callable = read_tiff,
+        **kwargs: Any,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Inference configuration.
+        src_files : list of pathlib.Path
+            List of data files.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the inference configuration.
+        """
+        self.prediction_config = prediction_config
+        self.data_files = src_files
+        self.axes = prediction_config.axes
+        self.read_source_func = read_source_func
+
+        # check mean and std and create the normalize transform
+        if (
+            self.prediction_config.image_means is None
+            or self.prediction_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.image_means = self.prediction_config.image_means
+            self.image_stds = self.prediction_config.image_stds
+
+        # instantiate normalize transform
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.image_means,
+                    image_stds=self.image_stds,
+                )
+            ],
+        )
+
+    def __iter__(
+        self,
+    ) -> Generator[NDArray, None, None]:
+        """
+        Iterate over the data source and yield single samples.
+
+        Yields
+        ------
+        NDArray
+            Single sample.
+        """
+        assert (
+            self.image_means is not None and self.image_stds is not None
+        ), "Mean and std must be provided"
+
+        for sample, _ in iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
+            # the sample has an S dimension; yield each element separately
+            for i in range(sample.shape[0]):
+                transformed_sample, _ = self.patch_transform(patch=sample[i])
+
+                yield transformed_sample
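A matching usage sketch for the untiled counterpart of the tiled dataset above. As before, the InferenceConfig arguments and the import path are assumptions; only the fields read by the dataset (axes, image_means, image_stds) appear in the hunk.

from pathlib import Path

from torch.utils.data import DataLoader

from careamics.config import InferenceConfig
from careamics.dataset import IterablePredDataset  # assumed import path

config = InferenceConfig(
    data_type="tiff",  # assumption, not shown in this diff
    axes="YX",
    image_means=[128.0],
    image_stds=[32.0],
)

files = sorted(Path("to_predict").glob("*.tif"))  # hypothetical folder
dataset = IterablePredDataset(config, src_files=files)

# IterableDataset: samples are yielded file by file, already normalized
loader = DataLoader(dataset, batch_size=1)
for batch in loader:
    ...  # run inference on each normalized sample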