careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/lightning/train_data_module.py
@@ -0,0 +1,680 @@
+"""Training and validation Lightning data modules."""
+
+from pathlib import Path
+from typing import Any, Callable, Literal, Optional, Union
+
+import numpy as np
+import pytorch_lightning as L
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader
+
+from careamics.config import DataConfig
+from careamics.config.data_model import TRANSFORMS_UNION
+from careamics.config.support import SupportedData
+from careamics.dataset.dataset_utils import (
+    get_files_size,
+    list_files,
+    validate_source_target_files,
+)
+from careamics.dataset.in_memory_dataset import (
+    InMemoryDataset,
+)
+from careamics.dataset.iterable_dataset import (
+    PathIterableDataset,
+)
+from careamics.file_io.read import get_read_func
+from careamics.utils import get_logger, get_ram_size
+
+DatasetType = Union[InMemoryDataset, PathIterableDataset]
+
+logger = get_logger(__name__)
+
+
+class TrainDataModule(L.LightningDataModule):
+    """
+    CAREamics Lightning training and validation data module.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterable dataset to train on a Path or str.
+
+    The data can be either a folder containing images or a single file.
+
+    Validation can be omitted, in which case the validation data is extracted from
+    the training data. The percentage of the training data to use for validation,
+    as well as the minimum number of patches or files to split from the training
+    data can be set using `val_percentage` and `val_minimum_split`, respectively.
+
+    To read custom data types, you can set `data_type` to `custom` in `data_config`
+    and provide a function that returns a numpy array from a path as
+    `read_source_func` parameter. The function will receive a Path object and
+    an axes string as arguments, the axes being derived from the `data_config`.
+
+    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.czi") to filter the file extensions using `extension_filter`.
+
+    Parameters
+    ----------
+    data_config : DataModel
+        Pydantic model for CAREamics data configuration.
+    train_data : pathlib.Path or str or numpy.ndarray
+        Training data, can be a path to a folder, a file or a numpy array.
+    val_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    train_data_target : pathlib.Path or str or numpy.ndarray, optional
+        Training target data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    val_data_target : pathlib.Path or str or numpy.ndarray, optional
+        Validation target data, can be a path to a folder, a file or a numpy array,
+        by default None.
+    read_source_func : Callable, optional
+        Function to read the source data, by default None. Only used for `custom`
+        data type (see DataModel).
+    extension_filter : str, optional
+        Filter for file extensions, by default "". Only used for `custom` data types
+        (see DataModel).
+    val_percentage : float, optional
+        Percentage of the training data to use for validation, by default 0.1. Only
+        used if `val_data` is None.
+    val_minimum_split : int, optional
+        Minimum number of patches or files to split from the training data for
+        validation, by default 5. Only used if `val_data` is None.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+
+    Attributes
+    ----------
+    data_config : DataModel
+        CAREamics data configuration.
+    data_type : SupportedData
+        Expected data type, one of "tiff", "array" or "custom".
+    batch_size : int
+        Batch size.
+    use_in_memory : bool
+        Whether to use in memory dataset if possible.
+    train_data : pathlib.Path or numpy.ndarray
+        Training data.
+    val_data : pathlib.Path or numpy.ndarray
+        Validation data.
+    train_data_target : pathlib.Path or numpy.ndarray
+        Training target data.
+    val_data_target : pathlib.Path or numpy.ndarray
+        Validation target data.
+    val_percentage : float
+        Percentage of the training data to use for validation, if no validation data is
+        provided.
+    val_minimum_split : int
+        Minimum number of patches or files to split from the training data for
+        validation, if no validation data is provided.
+    read_source_func : Optional[Callable]
+        Function to read the source data, used if `data_type` is `custom`.
+    extension_filter : str
+        Filter for file extensions, used if `data_type` is `custom`.
+    """
+
+    def __init__(
+        self,
+        data_config: DataConfig,
+        train_data: Union[Path, str, NDArray],
+        val_data: Optional[Union[Path, str, NDArray]] = None,
+        train_data_target: Optional[Union[Path, str, NDArray]] = None,
+        val_data_target: Optional[Union[Path, str, NDArray]] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+        use_in_memory: bool = True,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        data_config : DataModel
+            Pydantic model for CAREamics data configuration.
+        train_data : pathlib.Path or str or numpy.ndarray
+            Training data, can be a path to a folder, a file or a numpy array.
+        val_data : pathlib.Path or str or numpy.ndarray, optional
+            Validation data, can be a path to a folder, a file or a numpy array, by
+            default None.
+        train_data_target : pathlib.Path or str or numpy.ndarray, optional
+            Training target data, can be a path to a folder, a file or a numpy array, by
+            default None.
+        val_data_target : pathlib.Path or str or numpy.ndarray, optional
+            Validation target data, can be a path to a folder, a file or a numpy array,
+            by default None.
+        read_source_func : Callable, optional
+            Function to read the source data, by default None. Only used for `custom`
+            data type (see DataModel).
+        extension_filter : str, optional
+            Filter for file extensions, by default "". Only used for `custom` data types
+            (see DataModel).
+        val_percentage : float, optional
+            Percentage of the training data to use for validation, by default 0.1. Only
+            used if `val_data` is None.
+        val_minimum_split : int, optional
+            Minimum number of patches or files to split from the training data for
+            validation, by default 5. Only used if `val_data` is None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if target data is provided.
+        ValueError
+            If the input types are mixed (e.g. Path and numpy.ndarray).
+        ValueError
+            If the data type is `custom` and no `read_source_func` is provided.
+        ValueError
+            If the data type is `array` and the input is not a numpy array.
+        ValueError
+            If the data type is `tiff` and the input is neither a Path nor a str.
+        """
+        super().__init__()
+
+        # check input types coherence (no mixed types)
+        inputs = [train_data, val_data, train_data_target, val_data_target]
+        types_set = {type(i) for i in inputs}
+        if len(types_set) > 2:  # None + expected type
+            raise ValueError(
+                f"Inputs for `train_data`, `val_data`, `train_data_target` and "
+                f"`val_data_target` must be of the same type or None. Got "
+                f"{types_set}."
+            )
+
+        # check that a read source function is provided for custom types
+        if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+            raise ValueError(
+                f"Data type {SupportedData.CUSTOM} is not allowed without "
+                f"specifying a `read_source_func` and an `extension_filter`."
+            )
+
+        # check correct input type
+        if (
+            isinstance(train_data, np.ndarray)
+            and data_config.data_type != SupportedData.ARRAY
+        ):
+            raise ValueError(
+                f"Received a numpy array as input, but the data type was set to "
+                f"{data_config.data_type}. Set the data type in the configuration "
+                f"to {SupportedData.ARRAY} to train on numpy arrays."
+            )
+
+        # and that a Path or str is passed, if the tiff or custom data type is specified
+        elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+            data_config.data_type != SupportedData.TIFF
+            and data_config.data_type != SupportedData.CUSTOM
+        ):
+            raise ValueError(
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f"in the configuration to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to train on files."
+            )
+
+        # configuration
+        self.data_config: DataConfig = data_config
+        self.data_type: str = data_config.data_type
+        self.batch_size: int = data_config.batch_size
+        self.use_in_memory: bool = use_in_memory
+
+        # data: make data Path or np.ndarray, use type annotations for mypy
+        self.train_data: Union[Path, NDArray] = (
+            Path(train_data) if isinstance(train_data, str) else train_data
+        )
+
+        self.val_data: Union[Path, NDArray] = (
+            Path(val_data) if isinstance(val_data, str) else val_data
+        )
+
+        self.train_data_target: Union[Path, NDArray] = (
+            Path(train_data_target)
+            if isinstance(train_data_target, str)
+            else train_data_target
+        )
+
+        self.val_data_target: Union[Path, NDArray] = (
+            Path(val_data_target)
+            if isinstance(val_data_target, str)
+            else val_data_target
+        )
+
+        # validation split
+        self.val_percentage = val_percentage
+        self.val_minimum_split = val_minimum_split
+
+        # read source function corresponding to the requested type
+        if data_config.data_type == SupportedData.CUSTOM.value:
+            # mypy check
+            assert read_source_func is not None
+
+            self.read_source_func: Callable = read_source_func
+
+        elif data_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(data_config.data_type)
+
+        self.extension_filter: str = extension_filter
+
+        # PyTorch dataloader parameters
+        self.dataloader_params: dict[str, Any] = (
+            data_config.dataloader_params if data_config.dataloader_params else {}
+        )
+
+    def prepare_data(self) -> None:
+        """
+        Hook used to prepare the data before calling `setup`.
+
+        Here, we only need to examine the data if it was provided as a str or a Path.
+
+        TODO: from lightning doc:
+        prepare_data is called from the main process. It is not recommended to assign
+        state here (e.g. self.x = y) since it is called on a single process and if you
+        assign states here then they won't be available for other processes.
+
+        https://lightning.ai/docs/pytorch/stable/data/datamodule.html
+        """
+        # if the data is a Path or a str
+        if (
+            not isinstance(self.train_data, np.ndarray)
+            and not isinstance(self.val_data, np.ndarray)
+            and not isinstance(self.train_data_target, np.ndarray)
+            and not isinstance(self.val_data_target, np.ndarray)
+        ):
+            # list training files
+            self.train_files = list_files(
+                self.train_data, self.data_type, self.extension_filter
+            )
+            self.train_files_size = get_files_size(self.train_files)
+
+            # list validation files
+            if self.val_data is not None:
+                self.val_files = list_files(
+                    self.val_data, self.data_type, self.extension_filter
+                )
+
+            # same for target data
+            if self.train_data_target is not None:
+                self.train_target_files: list[Path] = list_files(
+                    self.train_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the training data
+                validate_source_target_files(self.train_files, self.train_target_files)
+
+            if self.val_data_target is not None:
+                self.val_target_files = list_files(
+                    self.val_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the validation data
+                validate_source_target_files(self.val_files, self.val_target_files)
+
+    def setup(self, *args: Any, **kwargs: Any) -> None:
+        """Hook called at the beginning of fit, validate, or predict.
+
+        Parameters
+        ----------
+        *args : Any
+            Unused.
+        **kwargs : Any
+            Unused.
+        """
+        # if numpy array
+        if self.data_type == SupportedData.ARRAY:
+            # mypy checks
+            assert isinstance(self.train_data, np.ndarray)
+            if self.train_data_target is not None:
+                assert isinstance(self.train_data_target, np.ndarray)
+
+            # train dataset
+            self.train_dataset: DatasetType = InMemoryDataset(
+                data_config=self.data_config,
+                inputs=self.train_data,
+                input_target=self.train_data_target,
+            )
+
+            # validation dataset
+            if self.val_data is not None:
+                # mypy checks
+                assert isinstance(self.val_data, np.ndarray)
+                if self.val_data_target is not None:
+                    assert isinstance(self.val_data_target, np.ndarray)
+
+                # create its own dataset
+                self.val_dataset: DatasetType = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.val_data,
+                    input_target=self.val_data_target,
+                )
+            else:
+                # extract validation from the training patches
+                self.val_dataset = self.train_dataset.split_dataset(
+                    percentage=self.val_percentage,
+                    minimum_patches=self.val_minimum_split,
+                )
+
+        # else we read files
+        else:
+            # Heuristic: if the total file size is smaller than 80% of the RAM,
+            # we run the training in memory, otherwise we switch to iterable dataset
+            # The switch is deactivated if use_in_memory is False
+            if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
+                # train dataset
+                self.train_dataset = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.train_files,
+                    input_target=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
+                    read_source_func=self.read_source_func,
+                )
+
+                # validation dataset
+                if self.val_data is not None:
+                    self.val_dataset = InMemoryDataset(
+                        data_config=self.data_config,
+                        inputs=self.val_files,
+                        input_target=(
+                            self.val_target_files if self.val_data_target else None
+                        ),
+                        read_source_func=self.read_source_func,
+                    )
+                else:
+                    # split dataset
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_patches=self.val_minimum_split,
+                    )
+
+            # else if the data is too large, load file by file during training
+            else:
+                # create training dataset
+                self.train_dataset = PathIterableDataset(
+                    data_config=self.data_config,
+                    src_files=self.train_files,
+                    target_files=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
+                    read_source_func=self.read_source_func,
+                )
+
+                # create validation dataset
+                if self.val_data is not None:
+                    # create its own dataset
+                    self.val_dataset = PathIterableDataset(
+                        data_config=self.data_config,
+                        src_files=self.val_files,
+                        target_files=(
+                            self.val_target_files if self.val_data_target else None
+                        ),
+                        read_source_func=self.read_source_func,
+                    )
+                elif len(self.train_files) <= self.val_minimum_split:
+                    raise ValueError(
+                        f"Not enough files to split a minimum of "
+                        f"{self.val_minimum_split} files, got {len(self.train_files)} "
+                        f"files."
+                    )
+                else:
+                    # extract validation from the training patches
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_number=self.val_minimum_split,
+                    )
+
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list
+            Means and standard deviations across channels of the training data.
+        """
+        return self.train_dataset.get_data_statistics()
+
+    def train_dataloader(self) -> Any:
+        """
+        Create a dataloader for training.
+
+        Returns
+        -------
+        Any
+            Training dataloader.
+        """
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+        )
+
+    def val_dataloader(self) -> Any:
+        """
+        Create a dataloader for validation.
+
+        Returns
+        -------
+        Any
+            Validation dataloader.
+        """
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+        )
+
+
+def create_train_datamodule(
+    train_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    patch_size: list[int],
+    axes: str,
+    batch_size: int,
+    val_data: Optional[Union[str, Path, NDArray]] = None,
+    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    train_target_data: Optional[Union[str, Path, NDArray]] = None,
+    val_target_data: Optional[Union[str, Path, NDArray]] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_patches: int = 5,
+    dataloader_params: Optional[dict] = None,
+    use_in_memory: bool = True,
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+) -> TrainDataModule:
+    """Create a TrainDataModule.
+
+    This function is used to explicitly pass the parameters usually contained in a
+    `data_model` configuration to a TrainDataModule.
+
+    Since the lightning datamodule has no access to the model, make sure that the
+    parameters passed to the datamodule are consistent with the model's requirements and
+    are coherent.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterable dataset to train on a Path or str.
+
+    To use array data, set `data_type` to `array` and pass a numpy array to
+    `train_data`.
+
+    In particular, N2V requires a specific transformation (N2V manipulate), which is
+    not compatible with supervised training. The default transformations applied to the
+    training patches are defined in `careamics.config.data_model`. To use different
+    transformations, pass a list of transforms. See examples for more details.
+
+    By default, CAREamics only supports types defined in
+    `careamics.config.support.SupportedData`. To read custom data types, you can set
+    `data_type` to `custom` and provide a function that returns a numpy array from a
+    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.jpeg") to filter the file extensions using `extension_filter`.
+
+    In the absence of validation data, the validation data is extracted from the
+    training data. The percentage of the training data to use for validation, as well as
+    the minimum number of patches to split from the training data for validation can be
+    set using `val_percentage` and `val_minimum_patches`, respectively.
+
+    In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
+    except for `batch_size`, which is set by the `batch_size` parameter.
+
+    Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2` to
+    use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to define
+    the axis and span of the structN2V mask. These parameters have no effect if
+    `train_target_data` or `transforms` are provided.
+
+    Parameters
+    ----------
+    train_data : pathlib.Path or str or numpy.ndarray
+        Training data.
+    data_type : {"array", "tiff", "custom"}
+        Data type, see `SupportedData` for available options.
+    patch_size : list of int
+        Patch size, 2D or 3D patch size.
+    axes : str
+        Axes of the data, chosen amongst SCZYX.
+    batch_size : int
+        Batch size.
+    val_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation data, by default None.
+    transforms : list of Transforms, optional
+        List of transforms to apply to training patches. If None, default transforms
+        are applied.
+    train_target_data : pathlib.Path or str or numpy.ndarray, optional
+        Training target data, by default None.
+    val_target_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation target data, by default None.
+    read_source_func : Callable, optional
+        Function to read the source data, used if `data_type` is `custom`, by
+        default None.
+    extension_filter : str, optional
+        Filter for file extensions, used if `data_type` is `custom`, by default "".
+    val_percentage : float, optional
+        Percentage of the training data to use for validation if no validation data
+        is given, by default 0.1.
+    val_minimum_patches : int, optional
+        Minimum number of patches to split from the training data for validation if
+        no validation data is given, by default 5.
+    dataloader_params : dict, optional
+        PyTorch dataloader parameters, by default {}.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+    use_n2v2 : bool, optional
+        Use N2V2 transformation during training, by default False.
+    struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
+        Axis for the structN2V mask, ignored if set to `none`, by
+        default "none".
+    struct_n2v_span : int, optional
+        Span for the structN2V mask, by default 5.
+
+    Returns
+    -------
+    TrainDataModule
+        CAREamics training Lightning data module.
+
+    Examples
+    --------
+    Create a TrainDataModule with default transforms from a numpy array:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> data_module = create_train_datamodule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ... )
+
+    For custom data types (those not supported by CAREamics), one can pass a read
+    function and a filter for the file extensions:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>>
+    >>> def read_npy(path):
+    ...     return np.load(path)
+    >>>
+    >>> data_module = create_train_datamodule(
+    ...     train_data="path/to/data",
+    ...     data_type="custom",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     read_source_func=read_npy,
+    ...     extension_filter="*.npy",
+    ... )
+
+    If you want to use a different set of transformations, you can pass a list of
+    transforms:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>> from careamics.config.support import SupportedTransform
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> my_transforms = [
+    ...     {
+    ...         "name": SupportedTransform.XY_FLIP.value,
+    ...     }
+    ... ]
+    >>> data_module = create_train_datamodule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     transforms=my_transforms,
+    ... )
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    data_dict: dict[str, Any] = {
+        "mode": "train",
+        "data_type": data_type,
+        "patch_size": patch_size,
+        "axes": axes,
+        "batch_size": batch_size,
+        "dataloader_params": dataloader_params,
+    }
+
+    # if transforms are passed (otherwise it will use the default ones)
+    if transforms is not None:
+        data_dict["transforms"] = transforms
+
+    # validate configuration
+    data_config = DataConfig(**data_dict)
+
+    # N2V-specific checks (N2V2, structN2V, and transforms)
+    if data_config.has_n2v_manipulate():
+        # if there is no target, N2V2 and structN2V can be changed
+        if train_target_data is None:
+            data_config.set_N2V2(use_n2v2)
+            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+        else:
+            raise ValueError(
+                "Cannot have both supervised training (target data) and "
+                "N2V manipulation in the transforms. Pass a list of transforms "
+                "that is compatible with your supervised training."
+            )
+
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return TrainDataModule(
+        data_config=data_config,
+        train_data=train_data,
+        val_data=val_data,
+        train_data_target=train_target_data,
+        val_data_target=val_target_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_patches,
+        use_in_memory=use_in_memory,
+    )
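A minimal usage sketch for the data module added above, mirroring its docstring example; the array shape and patch size are illustrative, and the `prepare_data`/`setup` hooks are normally driven by the Lightning Trainer rather than called by hand:

import numpy as np
from careamics.lightning import create_train_datamodule

# in-memory training data, so data_type="array" (see the input type checks in __init__ above)
train_array = np.random.rand(64, 64).astype(np.float32)

data_module = create_train_datamodule(
    train_data=train_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
)

# the Lightning Trainer normally calls these hooks; calling them manually here only
# illustrates the flow: patches are computed in memory and, since no val_data was
# given, 10% of them (at least 5) are split off for validation
data_module.prepare_data()
data_module.setup()
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()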
careamics/losses/__init__.py
@@ -0,0 +1,15 @@
+"""Losses module."""
+
+__all__ = [
+    "loss_factory",
+    "mae_loss",
+    "mse_loss",
+    "n2v_loss",
+    "denoisplit_loss",
+    "musplit_loss",
+    "denoisplit_musplit_loss",
+]
+
+from .fcn.losses import mae_loss, mse_loss, n2v_loss
+from .loss_factory import loss_factory
+from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
careamics/losses/fcn/__init__.py
@@ -0,0 +1 @@
+"""FCN losses."""
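A hedged sketch of how the loss exports above might be consumed; the `(prediction, target)` convention for `mse_loss` and the string identifier accepted by `loss_factory` are assumptions about the API, not taken from this diff:

import torch
from careamics.losses import loss_factory, mse_loss

prediction = torch.rand(2, 1, 16, 16)
target = torch.rand(2, 1, 16, 16)

# assumption: the FCN MSE loss follows the usual (prediction, target) convention
value = mse_loss(prediction, target)

# assumption: loss_factory resolves a loss callable from a SupportedLoss identifier (e.g. "mse")
loss_fn = loss_factory("mse")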