careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dataset preparation module.
|
|
3
|
+
|
|
4
|
+
Functions to set up the datasets for training, validation and prediction.
|
|
5
|
+
"""
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import List, Optional, Union
|
|
8
|
+
|
|
9
|
+
from careamics.config import Configuration
|
|
10
|
+
from careamics.manipulation import default_manipulate
|
|
11
|
+
from careamics.utils import check_tiling_validity
|
|
12
|
+
|
|
13
|
+
from .extraction_strategy import ExtractionStrategy
|
|
14
|
+
from .in_memory_dataset import InMemoryDataset
|
|
15
|
+
from .tiff_dataset import TiffDataset
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_train_dataset(
    config: Configuration, train_path: Union[str, Path]
) -> Union[TiffDataset, InMemoryDataset]:
    """
    Create training dataset.

    Depending on `config.data.in_memory`, this function returns either an
    InMemoryDataset (with sequential patch extraction) or a TiffDataset (with
    random patch extraction). In both cases every patch is passed through
    `default_manipulate`, parameterized by the masked pixel percentage and the
    ROI size of the algorithm configuration.

    Parameters
    ----------
    config : Configuration
        Configuration.
    train_path : Union[str, Path]
        Path to training data.

    Returns
    -------
    Union[TiffDataset, InMemoryDataset]
        Dataset.
    """
    # Arguments shared by both dataset types; only the dataset class and the
    # patch extraction strategy differ between the two branches below.
    common_kwargs = {
        "data_path": train_path,
        "data_format": config.data.data_format,
        "axes": config.data.axes,
        "mean": config.data.mean,
        "std": config.data.std,
        "patch_size": config.training.patch_size,
        "patch_transform": default_manipulate,
        "patch_transform_params": {
            "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
            "roi_size": config.algorithm.roi_size,
        },
    }

    if config.data.in_memory:
        return InMemoryDataset(
            patch_extraction_method=ExtractionStrategy.SEQUENTIAL, **common_kwargs
        )

    return TiffDataset(
        patch_extraction_method=ExtractionStrategy.RANDOM, **common_kwargs
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_validation_dataset(
    config: Configuration, val_path: Union[str, Path]
) -> InMemoryDataset:
    """
    Create validation dataset.

    The validation dataset is always kept in memory (InMemoryDataset) and uses
    sequential patch extraction. Patches are manipulated with
    `default_manipulate` using the same masked pixel percentage and ROI size
    as the training dataset, so that validation patches are masked the same
    way as training patches.

    Parameters
    ----------
    config : Configuration
        Configuration.
    val_path : Union[str, Path]
        Path to validation data.

    Returns
    -------
    InMemoryDataset
        In memory dataset.
    """
    return InMemoryDataset(
        data_path=val_path,
        data_format=config.data.data_format,
        axes=config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
        patch_size=config.training.patch_size,
        patch_transform=default_manipulate,
        patch_transform_params={
            "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
            # Previously omitted, which made validation fall back to the
            # manipulation function's own default ROI size instead of the
            # configured one used during training.
            "roi_size": config.algorithm.roi_size,
        },
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_prediction_dataset(
    config: Configuration,
    pred_path: Union[str, Path],
    *,
    tile_shape: Optional[List[int]] = None,
    overlaps: Optional[List[int]] = None,
    axes: Optional[str] = None,
) -> TiffDataset:
    """
    Create prediction dataset.

    To use tiling, both `tile_shape` and `overlaps` must be specified, have same
    length, be divisible by 2 and greater than 0. Finally, the overlaps must be
    smaller than the tiles. If only one of the two is given, a warning is
    emitted and prediction runs without tiling. Without tiling, the dataset
    yields whole images instead of tiles.

    By default, axes are extracted from the configuration. To use images with
    different axes, set the `axes` parameter. Note that the difference between
    configuration and parameter axes must be S or T, but not any of the spatial
    dimensions (e.g. 2D vs 3D).

    Parameters
    ----------
    config : Configuration
        Configuration.
    pred_path : Union[str, Path]
        Path to prediction data.
    tile_shape : Optional[List[int]], optional
        2D or 3D shape of the tiles, by default None.
    overlaps : Optional[List[int]], optional
        2D or 3D overlaps between tiles, by default None.
    axes : Optional[str], optional
        Axes of the data, by default None.

    Returns
    -------
    TiffDataset
        Dataset.
    """
    # Tiling is only possible when both tile shape and overlaps are provided.
    use_tiling = tile_shape is not None and overlaps is not None

    if use_tiling:
        # Validate tiles and overlaps against the requirements listed above.
        check_tiling_validity(tile_shape, overlaps)
    elif tile_shape is not None or overlaps is not None:
        # Only one tiling argument was given: tiling cannot be performed, so
        # tell the caller instead of silently ignoring the argument.
        import warnings

        warnings.warn(
            "Tiling requires both `tile_shape` and `overlaps`; only one was "
            "provided, so prediction will run without tiling.",
            stacklevel=2,
        )

    return TiffDataset(
        data_path=pred_path,
        data_format=config.data.data_format,
        axes=config.data.axes if axes is None else axes,  # supersede axes
        mean=config.data.mean,
        std=config.data.std,
        patch_size=tile_shape,
        patch_overlap=overlaps,
        # None makes TiffDataset yield whole images rather than patches.
        patch_extraction_method=ExtractionStrategy.TILED if use_tiling else None,
        patch_transform=None,
    )
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tiff dataset module.
|
|
3
|
+
|
|
4
|
+
This module contains the implementation of the TiffDataset class, which allows loading
|
|
5
|
+
tiff files.
|
|
6
|
+
"""
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from careamics.utils import normalize
|
|
14
|
+
from careamics.utils.logging import get_logger
|
|
15
|
+
|
|
16
|
+
from .dataset_utils import (
|
|
17
|
+
list_files,
|
|
18
|
+
read_tiff,
|
|
19
|
+
)
|
|
20
|
+
from .extraction_strategy import ExtractionStrategy
|
|
21
|
+
from .patching import generate_patches
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TiffDataset(torch.utils.data.IterableDataset):
    """
    Dataset allowing extracting patches from tiff images.

    Files are read one at a time while iterating. When used with several
    DataLoader workers, the files are distributed between workers (round-robin
    on the file index) so that each file is processed by exactly one worker.

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the data, must be a directory.
    data_format : str
        Extension of the files to load, without the period.
    axes : str
        Description of axes in format STCZYX.
    patch_extraction_method : Union[ExtractionStrategy, None]
        Patch extraction strategy, as defined in extraction_strategy. If None,
        whole images are yielded instead of patches.
    patch_size : Optional[Union[List[int], Tuple[int]]], optional
        Size of the patches in each dimension, by default None.
    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
        Overlap of the patches in each dimension, by default None.
    mean : Optional[float], optional
        Expected mean of the dataset, by default None.
    std : Optional[float], optional
        Expected standard deviation of the dataset, by default None.
    patch_transform : Optional[Callable], optional
        Patch transform callable, by default None.
    patch_transform_params : Optional[Dict], optional
        Patch transform parameters, by default None.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        data_format: str,  # TODO: TiffDataset should not know that they are tiff
        axes: str,
        patch_extraction_method: Union[ExtractionStrategy, None],
        patch_size: Optional[Union[List[int], Tuple[int]]] = None,
        patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
    ) -> None:
        """
        Constructor.

        If either `mean` or `std` is None, both are computed from the files
        found in `data_path` (see `_calculate_mean_and_std`).

        Parameters
        ----------
        data_path : Union[str, Path]
            Path to the data, must be a directory.
        data_format : str
            Extension of the files to load, without the period.
        axes : str
            Description of axes in format STCZYX.
        patch_extraction_method : Union[ExtractionStrategy, None]
            Patch extraction strategy, as defined in extraction_strategy. If
            None, whole images are yielded instead of patches.
        patch_size : Optional[Union[List[int], Tuple[int]]], optional
            Size of the patches in each dimension, by default None.
        patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
            Overlap of the patches in each dimension, by default None.
        mean : Optional[float], optional
            Mean of the dataset, by default None.
        std : Optional[float], optional
            Standard deviation of the dataset, by default None.
        patch_transform : Optional[Callable], optional
            Patch transform callable, by default None.
        patch_transform_params : Optional[Dict], optional
            Patch transform parameters, passed as keyword arguments to
            `patch_transform`, by default None (no extra arguments).

        Raises
        ------
        ValueError
            If data_path is not a directory, or if mean and std must be
            computed but no file could be read.
        """
        self.data_path = Path(data_path)
        if not self.data_path.is_dir():
            raise ValueError("Path to data should be an existing folder.")

        self.data_format = data_format
        self.axes = axes

        self.files = list_files(self.data_path, self.data_format)

        self.mean = mean
        self.std = std
        # Compare against None explicitly: 0.0 is a valid mean (e.g. for
        # already-centered data) and must not trigger a recomputation that
        # silently overrides the user-provided statistics.
        if mean is None or std is None:
            self.mean, self.std = self._calculate_mean_and_std()

        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.patch_extraction_method = patch_extraction_method
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

    def _calculate_mean_and_std(self) -> Tuple[float, float]:
        """
        Calculate mean and std of the dataset.

        The returned mean is the average of the per-image means, and the
        returned std is the average of the per-image standard deviations (not
        the standard deviation of all pixels pooled together).

        Returns
        -------
        Tuple[float, float]
            Tuple containing mean and standard deviation.

        Raises
        ------
        ValueError
            If no image could be read, which would otherwise lead to a
            division by zero.
        """
        means, stds = 0, 0
        num_samples = 0

        for sample in self._iterate_files():
            means += sample.mean()
            stds += np.std(sample)
            num_samples += 1

        if num_samples == 0:
            raise ValueError(
                f"No files with extension '{self.data_format}' found in "
                f"{self.data_path}, cannot compute mean and std."
            )

        result_mean = means / num_samples
        result_std = stds / num_samples

        logger.info(f"Calculated mean and std for {num_samples} images")
        logger.info(f"Mean: {result_mean}, std: {result_std}")
        return result_mean, result_std

    def _iterate_files(self) -> Generator[np.ndarray, None, None]:
        """
        Iterate over data source and yield whole image.

        Yields
        ------
        np.ndarray
            Image.
        """
        # When num_workers > 0, each worker process will have a different copy of the
        # dataset object
        # Configuring each copy independently to avoid having duplicate data returned
        # from the workers
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        for i, filename in enumerate(self.files):
            if i % num_workers == worker_id:
                sample = read_tiff(filename, self.axes)
                yield sample

    def __iter__(self) -> Generator[np.ndarray, None, None]:
        """
        Iterate over data source and yield single patch.

        If a patch extraction method is set, patches are extracted from each
        image, normalized and, if a patch transform is set, transformed.
        Otherwise each image is split along its first dimension and every item
        is yielded normalized, with two leading singleton dimensions added.

        Yields
        ------
        np.ndarray
            Single patch, or a tuple whose first element is the patch when the
            patch generator or the patch transform return tuples.
        """
        assert (
            self.mean is not None and self.std is not None
        ), "Mean and std must be provided"

        # Missing transform parameters mean "call the transform without extra
        # keyword arguments" instead of failing an assertion.
        transform_params = self.patch_transform_params or {}

        for sample in self._iterate_files():
            # TODO patch_extraction_method should never be None!
            if self.patch_extraction_method:
                # TODO: move S and T unpacking logic from patch generator
                patches = generate_patches(
                    sample,
                    self.patch_extraction_method,
                    self.patch_size,
                    self.patch_overlap,
                )

                for patch in patches:
                    if isinstance(patch, tuple):
                        # Only the image (first element) is normalized; the
                        # other elements are passed through unchanged.
                        normalized_patch = normalize(
                            img=patch[0], mean=self.mean, std=self.std
                        )
                        patch = (normalized_patch, *patch[1:])
                    else:
                        patch = normalize(img=patch, mean=self.mean, std=self.std)

                    if self.patch_transform is not None:
                        patch = self.patch_transform(patch, **transform_params)

                    yield patch

            else:
                # if S or T dims are not empty - assume every image is a separate
                # sample in dim 0
                for i in range(sample.shape[0]):
                    item = np.expand_dims(sample[i], (0, 1))
                    item = normalize(img=item, mean=self.mean, std=self.std)
                    yield item
|