careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
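The datamodule added in this release (careamics/lightning_datamodule.py, diffed in full below) supports formats outside `SupportedData` through a user-supplied `read_source_func` plus an `fnmatch`/`Path.rglob` compatible `extension_filter`. A minimal sketch of such a reader follows; the exact call signature is an assumption (the docstrings below mention both a plain `read_npy(path)` reader and a path-plus-axes call), so the catch-all arguments are defensive.

from pathlib import Path

import numpy as np


def read_npy(path: Path, *args, **kwargs) -> np.ndarray:
    # Return the image as a numpy array; extra arguments (e.g. the axes
    # string mentioned in the docstrings below) are accepted and ignored.
    return np.load(path)


# Typical wiring, per the docstring examples in the diff below:
#   CAREamicsTrainDataModule(..., data_type="custom",
#                            read_source_func=read_npy,
#                            extension_filter="*.npy")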
careamics/lightning_datamodule.py
@@ -0,0 +1,665 @@
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+
+import numpy as np
+import pytorch_lightning as L
+from albumentations import Compose
+from torch.utils.data import DataLoader
+
+from careamics.config import DataModel
+from careamics.config.data_model import TRANSFORMS_UNION
+from careamics.config.support import SupportedData
+from careamics.dataset.dataset_utils import (
+    get_files_size,
+    get_read_func,
+    list_files,
+    validate_source_target_files,
+)
+from careamics.dataset.in_memory_dataset import (
+    InMemoryDataset,
+)
+from careamics.dataset.iterable_dataset import (
+    PathIterableDataset,
+)
+from careamics.utils import get_logger, get_ram_size
+
+DatasetType = Union[InMemoryDataset, PathIterableDataset]
+
+logger = get_logger(__name__)
+
+
+class CAREamicsWood(L.LightningDataModule):
+    """
+    LightningDataModule for training and validation datasets.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterating dataset to train on a Path or str.
+
+    The data can be either a folder containing images or a single file.
+
+    Validation can be omitted, in which case the validation data is extracted from
+    the training data. The percentage of the training data to use for validation,
+    as well as the minimum number of patches or files to split from the training
+    data can be set using `val_percentage` and `val_minimum_split`, respectively.
+
+    To read custom data types, you can set `data_type` to `custom` in `data_config`
+    and provide a function that returns a numpy array from a path as
+    `read_source_func` parameter. The function will receive a Path object and
+    an axes string as arguments, the axes being derived from the `data_config`.
+
+    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.czi") to filter the file extensions using `extension_filter`.
+    """
+
+    def __init__(
+        self,
+        data_config: DataModel,
+        train_data: Union[Path, str, np.ndarray],
+        val_data: Optional[Union[Path, str, np.ndarray]] = None,
+        train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+        val_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+        use_in_memory: bool = True,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        data_config : DataModel
+            Pydantic model for CAREamics data configuration.
+        train_data : Union[Path, str, np.ndarray]
+            Training data, can be a path to a folder, a file or a numpy array.
+        val_data : Optional[Union[Path, str, np.ndarray]], optional
+            Validation data, can be a path to a folder, a file or a numpy array, by
+            default None.
+        train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+            Training target data, can be a path to a folder, a file or a numpy array, by
+            default None.
+        val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+            Validation target data, can be a path to a folder, a file or a numpy array,
+            by default None.
+        read_source_func : Optional[Callable], optional
+            Function to read the source data, by default None. Only used for `custom`
+            data type (see DataModel).
+        extension_filter : str, optional
+            Filter for file extensions, by default "". Only used for `custom` data types
+            (see DataModel).
+        val_percentage : float, optional
+            Percentage of the training data to use for validation, by default 0.1. Only
+            used if `val_data` is None.
+        val_minimum_split : int, optional
+            Minimum number of patches or files to split from the training data for
+            validation, by default 5. Only used if `val_data` is None.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if target data is provided.
+        ValueError
+            If the input types are mixed (e.g. Path and np.ndarray).
+        ValueError
+            If the data type is `custom` and no `read_source_func` is provided.
+        ValueError
+            If the data type is `array` and the input is not a numpy array.
+        ValueError
+            If the data type is `tiff` and the input is neither a Path nor a str.
+        """
+        super().__init__()
+
+        # check input types coherence (no mixed types)
+        inputs = [train_data, val_data, train_data_target, val_data_target]
+        types_set = {type(i) for i in inputs}
+        if len(types_set) > 2:  # None + expected type
+            raise ValueError(
+                f"Inputs for `train_data`, `val_data`, `train_data_target` and "
+                f"`val_data_target` must be of the same type or None. Got "
+                f"{types_set}."
+            )
+
+        # check that a read source function is provided for custom types
+        if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+            raise ValueError(
+                f"Data type {SupportedData.CUSTOM} is not allowed without "
+                f"specifying a `read_source_func`."
+            )
+
+        # and that arrays are passed, if array type specified
+        elif data_config.data_type == SupportedData.ARRAY and not isinstance(
+            train_data, np.ndarray
+        ):
+            raise ValueError(
+                f"Expected array input (see configuration.data.data_type), but got "
+                f"{type(train_data)} instead."
+            )
+
+        # and that Path or str are passed, if tiff file type specified
+        elif data_config.data_type == SupportedData.TIFF and (
+            not isinstance(train_data, Path) and not isinstance(train_data, str)
+        ):
+            raise ValueError(
+                f"Expected Path or str input (see configuration.data.data_type), "
+                f"but got {type(train_data)} instead."
+            )
+
+        # configuration
+        self.data_config = data_config
+        self.data_type = data_config.data_type
+        self.batch_size = data_config.batch_size
+        self.use_in_memory = use_in_memory
+
+        # data
+        self.train_data = train_data
+        self.val_data = val_data
+
+        self.train_data_target = train_data_target
+        self.val_data_target = val_data_target
+        self.val_percentage = val_percentage
+        self.val_minimum_split = val_minimum_split
+
+        # read source function corresponding to the requested type
+        if data_config.data_type == SupportedData.CUSTOM:
+            # mypy check
+            assert read_source_func is not None
+
+            self.read_source_func: Callable = read_source_func
+
+        elif data_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(data_config.data_type)
+
+        self.extension_filter = extension_filter
+
+        # Pytorch dataloader parameters
+        self.dataloader_params = (
+            data_config.dataloader_params if data_config.dataloader_params else {}
+        )
+
+    def prepare_data(self) -> None:
+        """
+        Hook used to prepare the data before calling `setup`.
+
+        Here, we only need to examine the data if it was provided as a str or a Path.
+
+        TODO: from lightning doc:
+        prepare_data is called from the main process. It is not recommended to assign
+        state here (e.g. self.x = y) since it is called on a single process and if you
+        assign states here then they won't be available for other processes.
+
+        https://lightning.ai/docs/pytorch/stable/data/datamodule.html
+        """
+        # if the data is a Path or a str
+        if (
+            not isinstance(self.train_data, np.ndarray)
+            and not isinstance(self.val_data, np.ndarray)
+            and not isinstance(self.train_data_target, np.ndarray)
+            and not isinstance(self.val_data_target, np.ndarray)
+        ):
+            # list training files
+            self.train_files = list_files(
+                self.train_data, self.data_type, self.extension_filter
+            )
+            self.train_files_size = get_files_size(self.train_files)
+
+            # list validation files
+            if self.val_data is not None:
+                self.val_files = list_files(
+                    self.val_data, self.data_type, self.extension_filter
+                )
+
+            # same for target data
+            if self.train_data_target is not None:
+                self.train_target_files: List[Path] = list_files(
+                    self.train_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the training data
+                validate_source_target_files(self.train_files, self.train_target_files)
+
+            if self.val_data_target is not None:
+                self.val_target_files = list_files(
+                    self.val_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the validation data
+                validate_source_target_files(self.val_files, self.val_target_files)
+
+    def setup(self, *args: Any, **kwargs: Any) -> None:
+        """Hook called at the beginning of fit, validate, or predict."""
+        # if numpy array
+        if self.data_type == SupportedData.ARRAY:
+            # train dataset
+            self.train_dataset: DatasetType = InMemoryDataset(
+                data_config=self.data_config,
+                inputs=self.train_data,
+                data_target=self.train_data_target,
+            )
+
+            # validation dataset
+            if self.val_data is not None:
+                # create its own dataset
+                self.val_dataset: DatasetType = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.val_data,
+                    data_target=self.val_data_target,
+                )
+            else:
+                # extract validation from the training patches
+                self.val_dataset = self.train_dataset.split_dataset(
+                    percentage=self.val_percentage,
+                    minimum_patches=self.val_minimum_split,
+                )
+
+        # else we read files
+        else:
+            # Heuristic: if the file size is smaller than 80% of the RAM,
+            # we run the training in memory, otherwise we switch to iterable dataset
+            # The switch is deactivated if use_in_memory is False
+            if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
+                # train dataset
+                self.train_dataset = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.train_files,
+                    data_target=self.train_target_files
+                    if self.train_data_target
+                    else None,
+                    read_source_func=self.read_source_func,
+                )
+
+                # validation dataset
+                if self.val_data is not None:
+                    self.val_dataset = InMemoryDataset(
+                        data_config=self.data_config,
+                        inputs=self.val_files,
+                        data_target=self.val_target_files
+                        if self.val_data_target
+                        else None,
+                        read_source_func=self.read_source_func,
+                    )
+                else:
+                    # split dataset
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_patches=self.val_minimum_split,
+                    )
+
+            # else if the data is too large, load file by file during training
+            else:
+                # create training dataset
+                self.train_dataset = PathIterableDataset(
+                    data_config=self.data_config,
+                    src_files=self.train_files,
+                    target_files=self.train_target_files
+                    if self.train_data_target
+                    else None,
+                    read_source_func=self.read_source_func,
+                )
+
+                # create validation dataset
+                if self.val_files is not None:
+                    # create its own dataset
+                    self.val_dataset = PathIterableDataset(
+                        data_config=self.data_config,
+                        src_files=self.val_files,
+                        target_files=self.val_target_files
+                        if self.val_data_target
+                        else None,
+                        read_source_func=self.read_source_func,
+                    )
+                elif len(self.train_files) <= self.val_minimum_split:
+                    raise ValueError(
+                        f"Not enough files to split a minimum of "
+                        f"{self.val_minimum_split} files, got {len(self.train_files)} "
+                        f"files."
+                    )
+                else:
+                    # extract validation from the training patches
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_files=self.val_minimum_split,
+                    )
+
+    def train_dataloader(self) -> Any:
+        """
+        Create a dataloader for training.
+
+        Returns
+        -------
+        Any
+            Training dataloader.
+        """
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+        )
+
+    def val_dataloader(self) -> Any:
+        """
+        Create a dataloader for validation.
+
+        Returns
+        -------
+        Any
+            Validation dataloader.
+        """
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+        )
+
+
+class CAREamicsTrainDataModule(CAREamicsWood):
+    """
+    LightningDataModule wrapper for training and validation datasets.
+
+    Since the lightning datamodule has no access to the model, make sure that the
+    parameters passed to the datamodule are consistent with the model's requirements and
+    are coherent.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterating dataset to train on a Path or str.
+
+    To use array data, set `data_type` to `array` and pass a numpy array to
+    `train_data`.
+
+    In particular, N2V requires a specific transformation (`N2VManipulate`), which is
+    not compatible with supervised training. The default transformations applied to the
+    training patches are defined in `careamics.config.data_model`. To use different
+    transformations, pass a list of transforms or an albumentation `Compose` as
+    `transforms` parameter. See examples for more details.
+
+    By default, CAREamics only supports types defined in
+    `careamics.config.support.SupportedData`. To read custom data types, you can set
+    `data_type` to `custom` and provide a function that returns a numpy array from a
+    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.jpeg") to filter the file extensions using `extension_filter`.
+
+    In the absence of validation data, the validation data is extracted from the
+    training data. The percentage of the training data to use for validation, as well as
+    the minimum number of patches to split from the training data for validation can be
+    set using `val_percentage` and `val_minimum_patches`, respectively.
+
+    In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
+    except for `batch_size`, which is set by the `batch_size` parameter.
+
+    Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2` to
+    use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to define
+    the axis and span of the structN2V mask. These parameters have no effect if
+    `train_target_data` or `transforms` are provided.
+
+    Parameters
+    ----------
+    train_data : Union[str, Path, np.ndarray]
+        Training data.
+    data_type : Union[str, SupportedData]
+        Data type, see `SupportedData` for available options.
+    patch_size : List[int]
+        Patch size, 2D or 3D patch size.
+    axes : str
+        Axes of the data, chosen amongst SCZYX.
+    batch_size : int
+        Batch size.
+    val_data : Optional[Union[str, Path]], optional
+        Validation data, by default None.
+    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        List of transforms to apply to training patches. If None, default transforms
+        are applied.
+    train_target_data : Optional[Union[str, Path]], optional
+        Training target data, by default None.
+    val_target_data : Optional[Union[str, Path]], optional
+        Validation target data, by default None.
+    read_source_func : Optional[Callable], optional
+        Function to read the source data, used if `data_type` is `custom`, by
+        default None.
+    extension_filter : str, optional
+        Filter for file extensions, used if `data_type` is `custom`, by default "".
+    val_percentage : float, optional
+        Percentage of the training data to use for validation if no validation data
+        is given, by default 0.1.
+    val_minimum_patches : int, optional
+        Minimum number of patches to split from the training data for validation if
+        no validation data is given, by default 5.
+    dataloader_params : dict, optional
+        Pytorch dataloader parameters, by default {}.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+    use_n2v2 : bool, optional
+        Use N2V2 transformation during training, by default False.
+    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+        Axis for the structN2V mask, not applied if `struct_n2v_axis` is `none`, by
+        default "none".
+    struct_n2v_span : int, optional
+        Span for the structN2V mask, by default 5.
+
+    Examples
+    --------
+    Create a CAREamicsTrainDataModule with default transforms using a numpy array:
+    >>> import numpy as np
+    >>> from careamics import CAREamicsTrainDataModule
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> data_module = CAREamicsTrainDataModule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ... )
+
+    For custom data types (those not supported by CAREamics), one can pass a read
+    function and a filter for the file extensions:
+    >>> import numpy as np
+    >>> from careamics import CAREamicsTrainDataModule
+    >>>
+    >>> def read_npy(path):
+    ...     return np.load(path)
+    >>>
+    >>> data_module = CAREamicsTrainDataModule(
+    ...     train_data="path/to/data",
+    ...     data_type="custom",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     read_source_func=read_npy,
+    ...     extension_filter="*.npy",
+    ... )
+
+    If you want to use a different set of transformations, you can pass a list of
+    transforms:
+    >>> import numpy as np
+    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics.config.support import SupportedTransform
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> my_transforms = [
+    ...     {
+    ...         "name": SupportedTransform.NORMALIZE.value,
+    ...         "mean": 0,
+    ...         "std": 1,
+    ...     },
+    ...     {
+    ...         "name": SupportedTransform.N2V_MANIPULATE.value,
+    ...     }
+    ... ]
+    >>> data_module = CAREamicsTrainDataModule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     transforms=my_transforms,
+    ... )
+    """
+
+    def __init__(
+        self,
+        train_data: Union[str, Path, np.ndarray],
+        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+        patch_size: List[int],
+        axes: str,
+        batch_size: int,
+        val_data: Optional[Union[str, Path]] = None,
+        transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
+        train_target_data: Optional[Union[str, Path]] = None,
+        val_target_data: Optional[Union[str, Path]] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_patches: int = 5,
+        dataloader_params: Optional[dict] = None,
+        use_in_memory: bool = True,
+        use_n2v2: bool = False,
+        struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+        struct_n2v_span: int = 5,
+    ) -> None:
+        """
+        LightningDataModule wrapper for training and validation datasets.
+
+        Since the lightning datamodule has no access to the model, make sure that the
+        parameters passed to the datamodule are consistent with the model's requirements
+        and are coherent.
+
+        The data module can be used with Path, str or numpy arrays. In the case of
+        numpy arrays, it loads and computes all the patches in memory. For Path and str
+        inputs, it calculates the total file size and estimates whether it can fit in
+        memory. If it does not, it iterates through the files. This behaviour can be
+        deactivated by setting `use_in_memory` to False, in which case it will
+        always use the iterating dataset to train on a Path or str.
+
+        To use array data, set `data_type` to `array` and pass a numpy array to
+        `train_data`.
+
+        In particular, N2V requires a specific transformation (`N2VManipulate`), which
+        is not compatible with supervised training. The default transformations applied
+        to the training patches are defined in `careamics.config.data_model`. To use
+        different transformations, pass a list of transforms or an albumentation
+        `Compose` as `transforms` parameter. See examples for more details.
+
+        By default, CAREamics only supports types defined in
+        `careamics.config.support.SupportedData`. To read custom data types, you can set
+        `data_type` to `custom` and provide a function that returns a numpy array from a
+        path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
+        (e.g. "*.jpeg") to filter the file extensions using `extension_filter`.
+
+        In the absence of validation data, the validation data is extracted from the
+        training data. The percentage of the training data to use for validation, as
+        well as the minimum number of patches to split from the training data for
+        validation can be set using `val_percentage` and `val_minimum_patches`,
+        respectively.
+
+        In `dataloader_params`, you can pass any parameter accepted by PyTorch
+        dataloaders, except for `batch_size`, which is set by the `batch_size`
+        parameter.
+
+        Finally, if you intend to use the N2V family of algorithms, you can set
+        `use_n2v2` to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span`
+        parameters to define the axis and span of the structN2V mask. These parameters
+        have no effect if `train_target_data` or `transforms` are provided.
+
+        Parameters
+        ----------
+        train_data : Union[str, Path, np.ndarray]
+            Training data.
+        data_type : Union[str, SupportedData]
+            Data type, see `SupportedData` for available options.
+        patch_size : List[int]
+            Patch size, 2D or 3D patch size.
+        axes : str
+            Axes of the data, chosen amongst SCZYX.
+        batch_size : int
+            Batch size.
+        val_data : Optional[Union[str, Path]], optional
+            Validation data, by default None.
+        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+            List of transforms to apply to training patches. If None, default transforms
+            are applied.
+        train_target_data : Optional[Union[str, Path]], optional
+            Training target data, by default None.
+        val_target_data : Optional[Union[str, Path]], optional
+            Validation target data, by default None.
+        read_source_func : Optional[Callable], optional
+            Function to read the source data, used if `data_type` is `custom`, by
+            default None.
+        extension_filter : str, optional
+            Filter for file extensions, used if `data_type` is `custom`, by default "".
+        val_percentage : float, optional
+            Percentage of the training data to use for validation if no validation data
+            is given, by default 0.1.
+        val_minimum_patches : int, optional
+            Minimum number of patches to split from the training data for validation if
+            no validation data is given, by default 5.
+        dataloader_params : dict, optional
+            Pytorch dataloader parameters, by default {}.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
+        use_n2v2 : bool, optional
+            Use N2V2 transformation during training, by default False.
+        struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+            Axis for the structN2V mask, not applied if `struct_n2v_axis` is `none`, by
+            default "none".
+        struct_n2v_span : int, optional
+            Span for the structN2V mask, by default 5.
+
+        Raises
+        ------
+        ValueError
+            If a target is set and N2V manipulation is present in the transforms.
+        """
+        if dataloader_params is None:
+            dataloader_params = {}
+        data_dict: Dict[str, Any] = {
+            "mode": "train",
+            "data_type": data_type,
+            "patch_size": patch_size,
+            "axes": axes,
+            "batch_size": batch_size,
+            "dataloader_params": dataloader_params,
+        }
+
+        # if transforms are passed (otherwise it will use the default ones)
+        if transforms is not None:
+            data_dict["transforms"] = transforms
+
+        # validate configuration
+        self.data_config = DataModel(**data_dict)
+
+        # N2V specific checks, N2V, structN2V, and transforms
+        if (
+            self.data_config.has_transform_list()
+            and self.data_config.has_n2v_manipulate()
+        ):
+            # there is no target, so n2v2 and structN2V can be changed
+            if train_target_data is None:
+                self.data_config.set_N2V2(use_n2v2)
+                self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+            else:
+                raise ValueError(
+                    "Cannot have both supervised training (target data) and "
+                    "N2V manipulation in the transforms. Pass a list of transforms "
+                    "that is compatible with your supervised training."
+                )
+
+        # sanity check on the dataloader parameters
+        if "batch_size" in dataloader_params:
+            # remove it
+            del dataloader_params["batch_size"]
+
+        super().__init__(
+            data_config=self.data_config,
+            train_data=train_data,
+            val_data=val_data,
+            train_data_target=train_target_data,
+            val_data_target=val_target_data,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            val_percentage=val_percentage,
+            val_minimum_split=val_minimum_patches,
+            use_in_memory=use_in_memory,
+        )
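The two classes above follow the standard pytorch_lightning datamodule contract (`prepare_data`, `setup`, `train_dataloader`, `val_dataloader`) and are consumed by a `Trainer` like any other `LightningDataModule`. The self-contained toy below (not careamics code; the model, data and class names are placeholders) sketches that contract end to end, assuming only that pytorch_lightning, torch and numpy are installed.

import numpy as np
import pytorch_lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(L.LightningDataModule):
    """Placeholder datamodule mirroring the hook structure used above."""

    def __init__(self, array: np.ndarray, batch_size: int = 2) -> None:
        super().__init__()
        self.array = array
        self.batch_size = batch_size

    def setup(self, stage=None) -> None:
        # split patches into train/val, loosely mirroring split_dataset() above
        data = torch.from_numpy(self.array).float().unsqueeze(1)  # N, C, H, W
        n_val = max(1, int(0.1 * len(data)))
        self.train_dataset = TensorDataset(data[n_val:])
        self.val_dataset = TensorDataset(data[:n_val])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


class ToyModel(L.LightningModule):
    """Placeholder model, not the module shipped in careamics/lightning_module.py."""

    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return torch.nn.functional.mse_loss(self.net(x), x)

    def validation_step(self, batch, batch_idx):
        (x,) = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.net(x), x))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    patches = np.random.rand(32, 16, 16).astype(np.float32)
    trainer = L.Trainer(max_epochs=1, accelerator="cpu", logger=False)
    trainer.fit(ToyModel(), datamodule=ToyDataModule(patches))

The careamics classes differ mainly in how `setup` chooses between `InMemoryDataset` and `PathIterableDataset` based on the file-size heuristic shown in the diff above.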