careamics-0.0.11-py3-none-any.whl → careamics-0.0.13-py3-none-any.whl
- careamics/careamist.py +24 -7
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +55 -4
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +41 -4
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/optimizer_models.py +1 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/training_model.py +0 -2
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +229 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +451 -0
- careamics/dataset_ng/legacy_interoperability.py +170 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +678 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
- careamics/lightning/lightning_module.py +5 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/compose.py +1 -0
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/normalize.py +18 -7
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +25 -11
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/data_module.py (new file)

@@ -0,0 +1,678 @@

```python
"""Next-Generation CAREamics DataModule."""

from collections.abc import Callable
from pathlib import Path
from typing import Any, Optional, Union, overload

import numpy as np
import pytorch_lightning as L
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from careamics.config.data.ng_data_model import NGDataConfig
from careamics.config.support import SupportedData
from careamics.dataset.dataset_utils import list_files, validate_source_target_files
from careamics.dataset_ng.dataset import Mode
from careamics.dataset_ng.factory import create_dataset
from careamics.dataset_ng.patch_extractor import ImageStackLoader
from careamics.utils import get_logger

logger = get_logger(__name__)

ItemType = Union[Path, str, NDArray[Any]]
"""Type of input items passed to the dataset."""

InputType = Union[ItemType, list[ItemType], None]
"""Type of input data passed to the dataset."""


class CareamicsDataModule(L.LightningDataModule):
    """Data module for CAREamics datasets.

    Parameters
    ----------
    data_config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    train_data : Optional[InputType]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[InputType]
        Training data target, can be a path to a folder,
        a list of paths, or a numpy array.
    val_data : Optional[InputType]
        Validation data, can be a path to a folder,
        a list of paths, or a numpy array.
    val_data_target : Optional[InputType]
        Validation data target, can be a path to a folder,
        a list of paths, or a numpy array.
    pred_data : Optional[InputType]
        Prediction data, can be a path to a folder, a list of paths,
        or a numpy array.
    pred_data_target : Optional[InputType]
        Prediction data target, can be a path to a folder,
        a list of paths, or a numpy array.
    read_source_func : Optional[Callable], default=None
        Function to read the source data. Only used for `custom`
        data type (see DataModel).
    read_kwargs : Optional[dict[str, Any]]
        The kwargs for the read source function.
    image_stack_loader : Optional[ImageStackLoader]
        The image stack loader.
    image_stack_loader_kwargs : Optional[dict[str, Any]]
        The image stack loader kwargs.
    extension_filter : str, default=""
        Filter for file extensions. Only used for `custom` data types
        (see DataModel).
    val_percentage : Optional[float]
        Percentage of the training data to use for validation. Only
        used if `val_data` is None.
    val_minimum_split : int, default=5
        Minimum number of patches or files to split from the training data for
        validation. Only used if `val_data` is None.
    use_in_memory : bool
        Load data into an in-memory dataset if possible, by default True.

    Attributes
    ----------
    config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    data_type : str
        Type of data, one of SupportedData.
    batch_size : int
        Batch size for the dataloaders.
    use_in_memory : bool
        Whether to load data in memory if possible.
    extension_filter : str
        Filter for file extensions, by default "".
    read_source_func : Optional[Callable], default=None
        Function to read the source data.
    read_kwargs : Optional[dict[str, Any]], default=None
        The kwargs for the read source function.
    val_percentage : Optional[float]
        Percentage of the training data to use for validation.
    val_minimum_split : int, default=5
        Minimum number of patches or files to split from the training data for
        validation.
    train_data : Optional[Any]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[Any]
        Training data target, can be a path to a folder, a list of paths, or a numpy
        array.
    val_data : Optional[Any]
        Validation data, can be a path to a folder, a list of paths, or a numpy array.
    val_data_target : Optional[Any]
        Validation data target, can be a path to a folder, a list of paths, or a numpy
        array.
    pred_data : Optional[Any]
        Prediction data, can be a path to a folder, a list of paths, or a numpy array.
    pred_data_target : Optional[Any]
        Prediction data target, can be a path to a folder, a list of paths, or a numpy
        array.

    Raises
    ------
    ValueError
        If none of train_data, val_data or pred_data is provided.
    ValueError
        If input and target data types are not consistent.
    """

    # standard use
    @overload
    def __init__(
        self,
        data_config: NGDataConfig,
        *,
        train_data: Optional[InputType] = None,
        train_data_target: Optional[InputType] = None,
        val_data: Optional[InputType] = None,
        val_data_target: Optional[InputType] = None,
        pred_data: Optional[InputType] = None,
        pred_data_target: Optional[InputType] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    # custom read function
    @overload
    def __init__(
        self,
        data_config: NGDataConfig,
        *,
        train_data: Optional[InputType] = None,
        train_data_target: Optional[InputType] = None,
        val_data: Optional[InputType] = None,
        val_data_target: Optional[InputType] = None,
        pred_data: Optional[InputType] = None,
        pred_data_target: Optional[InputType] = None,
        read_source_func: Callable,
        read_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    # custom image stack loader
    @overload
    def __init__(
        self,
        data_config: NGDataConfig,
        *,
        train_data: Optional[Any] = None,
        train_data_target: Optional[Any] = None,
        val_data: Optional[Any] = None,
        val_data_target: Optional[Any] = None,
        pred_data: Optional[Any] = None,
        pred_data_target: Optional[Any] = None,
        image_stack_loader: ImageStackLoader,
        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None: ...

    def __init__(
        self,
        data_config: NGDataConfig,
        *,
        train_data: Optional[Any] = None,
        train_data_target: Optional[Any] = None,
        val_data: Optional[Any] = None,
        val_data_target: Optional[Any] = None,
        pred_data: Optional[Any] = None,
        pred_data_target: Optional[Any] = None,
        read_source_func: Optional[Callable] = None,
        read_kwargs: Optional[dict[str, Any]] = None,
        image_stack_loader: Optional[ImageStackLoader] = None,
        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
        extension_filter: str = "",
        val_percentage: Optional[float] = None,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ) -> None:
        """
        Data module for CAREamics dataset initialization.

        Create a lightning datamodule that handles creating datasets for training,
        validation, and prediction.

        Parameters
        ----------
        data_config : NGDataConfig
            Pydantic model for CAREamics data configuration.
        train_data : Optional[InputType]
            Training data, can be a path to a folder, a list of paths, or a numpy
            array.
        train_data_target : Optional[InputType]
            Training data target, can be a path to a folder,
            a list of paths, or a numpy array.
        val_data : Optional[InputType]
            Validation data, can be a path to a folder,
            a list of paths, or a numpy array.
        val_data_target : Optional[InputType]
            Validation data target, can be a path to a folder,
            a list of paths, or a numpy array.
        pred_data : Optional[InputType]
            Prediction data, can be a path to a folder, a list of paths,
            or a numpy array.
        pred_data_target : Optional[InputType]
            Prediction data target, can be a path to a folder,
            a list of paths, or a numpy array.
        read_source_func : Optional[Callable]
            Function to read the source data, by default None. Only used for
            `custom` data type (see DataModel).
        read_kwargs : Optional[dict[str, Any]]
            The kwargs for the read source function.
        image_stack_loader : Optional[ImageStackLoader]
            The image stack loader.
        image_stack_loader_kwargs : Optional[dict[str, Any]]
            The image stack loader kwargs.
        extension_filter : str
            Filter for file extensions, by default "". Only used for `custom`
            data types (see DataModel).
        val_percentage : Optional[float]
            Percentage of the training data to use for validation. Only
            used if `val_data` is None.
        val_minimum_split : int
            Minimum number of patches or files to split from the training data for
            validation, by default 5. Only used if `val_data` is None.
        use_in_memory : bool
            Load data into an in-memory dataset if possible, by default True.
        """
        super().__init__()

        if train_data is None and val_data is None and pred_data is None:
            raise ValueError(
                "At least one of train_data, val_data or pred_data must be provided."
            )

        self.config: NGDataConfig = data_config
        self.data_type: str = data_config.data_type
        self.batch_size: int = data_config.batch_size
        self.use_in_memory: bool = use_in_memory
        self.extension_filter: str = extension_filter
        self.read_source_func = read_source_func
        self.read_kwargs = read_kwargs
        self.image_stack_loader = image_stack_loader
        self.image_stack_loader_kwargs = image_stack_loader_kwargs

        # TODO: implement the validation split logic
        self.val_percentage = val_percentage
        self.val_minimum_split = val_minimum_split
        if self.val_percentage is not None:
            raise NotImplementedError("Validation split not implemented")

        self.train_data, self.train_data_target = self._initialize_data_pair(
            train_data, train_data_target
        )
        self.val_data, self.val_data_target = self._initialize_data_pair(
            val_data, val_data_target
        )

        # pred_data_target may be needed to compute metrics on the predictions
        self.pred_data, self.pred_data_target = self._initialize_data_pair(
            pred_data, pred_data_target
        )

    def _validate_input_target_type_consistency(
        self,
        input_data: InputType,
        target_data: Optional[InputType],
    ) -> None:
        """Validate that the input and target data types are consistent.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.
        """
        if input_data is not None and target_data is not None:
            if not isinstance(input_data, type(target_data)):
                raise ValueError(
                    "Input and target must be of the same type or None. "
                    f"Got {type(input_data)} and {type(target_data)}."
                )
            if isinstance(input_data, list) and isinstance(target_data, list):
                if len(input_data) != len(target_data):
                    raise ValueError(
                        "Inputs and targets must have the same length. "
                        f"Got {len(input_data)} and {len(target_data)}."
                    )
                if not isinstance(input_data[0], type(target_data[0])):
                    raise ValueError(
                        "Inputs and targets must have the same type. "
                        f"Got {type(input_data[0])} and {type(target_data[0])}."
                    )

    def _list_files_in_directory(
        self,
        input_data,
        target_data=None,
    ) -> tuple[list[Path], Optional[list[Path]]]:
        """List files from input and target directories.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (list[Path], Optional[list[Path]])
            A tuple containing lists of file paths for input and target data.
            If target_data is None, the second element will be None.
        """
        input_data = Path(input_data)
        input_files = list_files(input_data, self.data_type, self.extension_filter)
        if target_data is None:
            return input_files, None
        else:
            target_data = Path(target_data)
            target_files = list_files(
                target_data, self.data_type, self.extension_filter
            )
            validate_source_target_files(input_files, target_files)
            return input_files, target_files

    def _convert_paths_to_pathlib(
        self,
        input_data,
        target_data=None,
    ) -> tuple[list[Path], Optional[list[Path]]]:
        """Create a list of file paths from the input and target data.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (list[Path], Optional[list[Path]])
            A tuple containing lists of file paths for input and target data.
            If target_data is None, the second element will be None.
        """
        input_files = [
            Path(item) if isinstance(item, str) else item for item in input_data
        ]
        if target_data is None:
            return input_files, None
        else:
            target_files = [
                Path(item) if isinstance(item, str) else item for item in target_data
            ]
            validate_source_target_files(input_files, target_files)
            return input_files, target_files

    def _validate_array_input(
        self,
        input_data: InputType,
        target_data: Optional[InputType],
    ) -> tuple[Any, Any]:
        """Validate that the input data is a numpy array.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (Any, Any)
            A tuple containing the input and target.
        """
        if isinstance(input_data, np.ndarray):
            input_array = [input_data]
            target_array = [target_data] if target_data is not None else None
            return input_array, target_array
        elif isinstance(input_data, list):
            return input_data, target_data
        else:
            raise ValueError(
                f"Unsupported input type for {self.data_type}: {type(input_data)}"
            )

    def _validate_path_input(
        self, input_data: InputType, target_data: Optional[InputType]
    ) -> tuple[list[Path], Optional[list[Path]]]:
        """Validate that the input data is a path or a list of paths.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (list[Path], Optional[list[Path]])
            A tuple containing lists of file paths for input and target data.
            If target_data is None, the second element will be None.
        """
        if isinstance(input_data, str | Path):
            if target_data is not None:
                assert isinstance(target_data, str | Path)
            input_list, target_list = self._list_files_in_directory(
                input_data, target_data
            )
            return input_list, target_list
        elif isinstance(input_data, list):
            if target_data is not None:
                assert isinstance(target_data, list)
            input_list, target_list = self._convert_paths_to_pathlib(
                input_data, target_data
            )
            return input_list, target_list
        else:
            raise ValueError(
                f"Unsupported input type for {self.data_type}: {type(input_data)}"
            )

    def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
        """Convert custom input data to a list of file paths.

        Parameters
        ----------
        input_data : InputType
            Input data, can be a path to a folder, a list of paths, or a numpy array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (Any, Any)
            A tuple containing lists of file paths for input and target data.
            If target_data is None, the second element will be None.
        """
        if self.image_stack_loader is not None:
            return input_data, target_data
        elif isinstance(input_data, str | Path):
            if target_data is not None:
                assert isinstance(target_data, str | Path)
            input_list, target_list = self._list_files_in_directory(
                input_data, target_data
            )
            return input_list, target_list
        elif isinstance(input_data, list):
            if isinstance(input_data[0], str | Path):
                if target_data is not None:
                    assert isinstance(target_data, list)
                input_list, target_list = self._convert_paths_to_pathlib(
                    input_data, target_data
                )
                return input_list, target_list
            else:
                raise ValueError(
                    f"If using {self.data_type}, pass a custom "
                    "image_stack_loader or read_source_func"
                )
        return input_data, target_data

    def _initialize_data_pair(
        self,
        input_data: Optional[InputType],
        target_data: Optional[InputType],
    ) -> tuple[Any, Any]:
        """
        Initialize a pair of input and target data.

        Parameters
        ----------
        input_data : InputType
            Input data, can be None, a path to a folder, a list of paths, or a numpy
            array.
        target_data : Optional[InputType]
            Target data, can be None, a path to a folder, a list of paths, or a numpy
            array.

        Returns
        -------
        (list of numpy.ndarray or list of pathlib.Path, None or list of numpy.ndarray
        or list of pathlib.Path)
            A tuple containing the initialized input and target data. For file paths,
            returns lists of Path objects. For numpy arrays, returns the arrays
            directly.
        """
        if input_data is None:
            return None, None

        self._validate_input_target_type_consistency(input_data, target_data)

        if self.data_type == SupportedData.ARRAY:
            if isinstance(input_data, np.ndarray):
                return self._validate_array_input(input_data, target_data)
            elif isinstance(input_data, list):
                if isinstance(input_data[0], np.ndarray):
                    return self._validate_array_input(input_data, target_data)
                else:
                    raise ValueError(
                        f"Unsupported input type for {self.data_type}: "
                        f"{type(input_data[0])}"
                    )
            else:
                raise ValueError(
                    f"Unsupported input type for {self.data_type}: {type(input_data)}"
                )
        elif self.data_type in (SupportedData.TIFF, SupportedData.CZI):
            if isinstance(input_data, str | Path):
                return self._validate_path_input(input_data, target_data)
            elif isinstance(input_data, list):
                if isinstance(input_data[0], str | Path):
                    return self._validate_path_input(input_data, target_data)
                else:
                    raise ValueError(
                        f"Unsupported input type for {self.data_type}: "
                        f"{type(input_data[0])}"
                    )
            else:
                raise ValueError(
                    f"Unsupported input type for {self.data_type}: {type(input_data)}"
                )
        elif self.data_type == SupportedData.CUSTOM:
            return self._validate_custom_input(input_data, target_data)
        else:
            raise NotImplementedError(f"Unsupported data type: {self.data_type}")

    def setup(self, stage: str) -> None:
        """
        Set up datasets.

        Lightning hook that is called at the beginning of fit (train + validate),
        validate, test, or predict. Creates the datasets for a given stage.

        Parameters
        ----------
        stage : str
            The stage to set up datasets for.
            Is either 'fit', 'validate', 'test', or 'predict'.

        Raises
        ------
        NotImplementedError
            If stage is not one of "fit", "validate" or "predict".
        """
        if stage == "fit":
            self.train_dataset = create_dataset(
                mode=Mode.TRAINING,
                inputs=self.train_data,
                targets=self.train_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            # TODO: ugly, need to find a better solution
            self.stats = self.train_dataset.input_stats
            self.config.set_means_and_stds(
                self.train_dataset.input_stats.means,
                self.train_dataset.input_stats.stds,
                self.train_dataset.target_stats.means,
                self.train_dataset.target_stats.stds,
            )
            self.val_dataset = create_dataset(
                mode=Mode.VALIDATING,
                inputs=self.val_data,
                targets=self.val_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
        elif stage == "validate":
            self.val_dataset = create_dataset(
                mode=Mode.VALIDATING,
                inputs=self.val_data,
                targets=self.val_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            self.stats = self.val_dataset.input_stats
        elif stage == "predict":
            self.predict_dataset = create_dataset(
                mode=Mode.PREDICTING,
                inputs=self.pred_data,
                targets=self.pred_data_target,
                config=self.config,
                in_memory=self.use_in_memory,
                read_func=self.read_source_func,
                read_kwargs=self.read_kwargs,
                image_stack_loader=self.image_stack_loader,
                image_stack_loader_kwargs=self.image_stack_loader_kwargs,
            )
            self.stats = self.predict_dataset.input_stats
        else:
            raise NotImplementedError(f"Stage {stage} not implemented")

    def train_dataloader(self) -> DataLoader:
        """
        Create a dataloader for training.

        Returns
        -------
        DataLoader
            Training dataloader.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            **self.config.train_dataloader_params,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Create a dataloader for validation.

        Returns
        -------
        DataLoader
            Validation dataloader.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            **self.config.val_dataloader_params,
        )

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        # prediction reuses the config's test dataloader parameters
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            collate_fn=default_collate,
            **self.config.test_dataloader_params,
        )
```
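To show how the pieces above fit together, here is a minimal usage sketch of the new `CareamicsDataModule`. It is illustrative only, not code from this release: the `NGDataConfig` keyword arguments shown (`data_type`, `axes`, `batch_size`) are assumptions inferred from the attributes the datamodule reads, and should be checked against `careamics/config/data/ng_data_model.py`, which likely also expects a patching strategy from `careamics/config/data/patching_strategies/`. The `model` placeholder stands in for one of the Lightning modules added under `careamics/lightning/dataset_ng/lightning_modules/`.

```python
# Illustrative sketch -- the NGDataConfig field names below are assumptions,
# not confirmed by this diff; see careamics/config/data/ng_data_model.py.
import numpy as np
import pytorch_lightning as L

from careamics.config.data.ng_data_model import NGDataConfig
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

rng = np.random.default_rng(0)
train = rng.random((2, 256, 256), dtype=np.float32)  # toy "SYX" stack
val = rng.random((1, 256, 256), dtype=np.float32)

config = NGDataConfig(  # hypothetical field values, shown for shape only
    data_type="array",
    axes="SYX",
    batch_size=4,
)

# At least one of train/val/pred data must be given, or __init__ raises.
datamodule = CareamicsDataModule(config, train_data=train, val_data=val)

# trainer = L.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=datamodule)
```

Note that `setup("fit")` computes normalization statistics on the training dataset and writes them back into the config via `set_means_and_stds`, so the validation and prediction datasets are built with the training statistics.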