careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/models/unet.py +16 -10
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
"""Next-Generation CAREamics DataModule."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
1
4
|
from pathlib import Path
|
|
2
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Optional, Union, overload
|
|
3
6
|
|
|
4
7
|
import numpy as np
|
|
5
8
|
import pytorch_lightning as L
|
|
@@ -7,7 +10,7 @@ from numpy.typing import NDArray
|
|
|
7
10
|
from torch.utils.data import DataLoader
|
|
8
11
|
from torch.utils.data._utils.collate import default_collate
|
|
9
12
|
|
|
10
|
-
from careamics.config.data import
|
|
13
|
+
from careamics.config.data.ng_data_model import NGDataConfig
|
|
11
14
|
from careamics.config.support import SupportedData
|
|
12
15
|
from careamics.dataset.dataset_utils import list_files, validate_source_target_files
|
|
13
16
|
from careamics.dataset_ng.dataset import Mode
|
|
@@ -18,16 +21,108 @@ from careamics.utils import get_logger
|
|
|
18
21
|
logger = get_logger(__name__)
|
|
19
22
|
|
|
20
23
|
ItemType = Union[Path, str, NDArray[Any]]
|
|
24
|
+
"""Type of input items passed to the dataset."""
|
|
25
|
+
|
|
21
26
|
InputType = Union[ItemType, list[ItemType], None]
|
|
27
|
+
"""Type of input data passed to the dataset."""
|
|
22
28
|
|
|
23
29
|
|
|
24
30
|
class CareamicsDataModule(L.LightningDataModule):
|
|
25
|
-
"""Data module for Careamics dataset.
|
|
26
|
-
|
|
31
|
+
"""Data module for Careamics dataset.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
data_config : NGDataConfig
|
|
36
|
+
Pydantic model for CAREamics data configuration.
|
|
37
|
+
train_data : Optional[InputType]
|
|
38
|
+
Training data, can be a path to a folder, a list of paths, or a numpy array.
|
|
39
|
+
train_data_target : Optional[InputType]
|
|
40
|
+
Training data target, can be a path to a folder,
|
|
41
|
+
a list of paths, or a numpy array.
|
|
42
|
+
val_data : Optional[InputType]
|
|
43
|
+
Validation data, can be a path to a folder,
|
|
44
|
+
a list of paths, or a numpy array.
|
|
45
|
+
val_data_target : Optional[InputType]
|
|
46
|
+
Validation data target, can be a path to a folder,
|
|
47
|
+
a list of paths, or a numpy array.
|
|
48
|
+
pred_data : Optional[InputType]
|
|
49
|
+
Prediction data, can be a path to a folder, a list of paths,
|
|
50
|
+
or a numpy array.
|
|
51
|
+
pred_data_target : Optional[InputType]
|
|
52
|
+
Prediction data target, can be a path to a folder,
|
|
53
|
+
a list of paths, or a numpy array.
|
|
54
|
+
read_source_func : Optional[Callable], default=None
|
|
55
|
+
Function to read the source data. Only used for `custom`
|
|
56
|
+
data type (see DataModel).
|
|
57
|
+
read_kwargs : Optional[dict[str, Any]]
|
|
58
|
+
The kwargs for the read source function.
|
|
59
|
+
image_stack_loader : Optional[ImageStackLoader]
|
|
60
|
+
The image stack loader.
|
|
61
|
+
image_stack_loader_kwargs : Optional[dict[str, Any]]
|
|
62
|
+
The image stack loader kwargs.
|
|
63
|
+
extension_filter : str, default=""
|
|
64
|
+
Filter for file extensions. Only used for `custom` data types
|
|
65
|
+
(see DataModel).
|
|
66
|
+
val_percentage : Optional[float]
|
|
67
|
+
Percentage of the training data to use for validation. Only
|
|
68
|
+
used if `val_data` is None.
|
|
69
|
+
val_minimum_split : int, default=5
|
|
70
|
+
Minimum number of patches or files to split from the training data for
|
|
71
|
+
validation. Only used if `val_data` is None.
|
|
72
|
+
use_in_memory : bool
|
|
73
|
+
Whether to load the data in memory if possible, by default True.
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
Attributes
|
|
77
|
+
----------
|
|
78
|
+
config : NGDataConfig
|
|
79
|
+
Pydantic model for CAREamics data configuration.
|
|
80
|
+
data_type : str
|
|
81
|
+
Type of data, one of SupportedData.
|
|
82
|
+
batch_size : int
|
|
83
|
+
Batch size for the dataloaders.
|
|
84
|
+
use_in_memory : bool
|
|
85
|
+
Whether to load data in memory if possible.
|
|
86
|
+
extension_filter : str
|
|
87
|
+
Filter for file extensions, by default "".
|
|
88
|
+
read_source_func : Optional[Callable], default=None
|
|
89
|
+
Function to read the source data.
|
|
90
|
+
read_kwargs : Optional[dict[str, Any]], default=None
|
|
91
|
+
The kwargs for the read source function.
|
|
92
|
+
val_percentage : Optional[float]
|
|
93
|
+
Percentage of the training data to use for validation.
|
|
94
|
+
val_minimum_split : int, default=5
|
|
95
|
+
Minimum number of patches or files to split from the training data for
|
|
96
|
+
validation.
|
|
97
|
+
train_data : Optional[Any]
|
|
98
|
+
Training data, can be a path to a folder, a list of paths, or a numpy array.
|
|
99
|
+
train_data_target : Optional[Any]
|
|
100
|
+
Training data target, can be a path to a folder, a list of paths, or a numpy
|
|
101
|
+
array.
|
|
102
|
+
val_data : Optional[Any]
|
|
103
|
+
Validation data, can be a path to a folder, a list of paths, or a numpy array.
|
|
104
|
+
val_data_target : Optional[Any]
|
|
105
|
+
Validation data target, can be a path to a folder, a list of paths, or a numpy
|
|
106
|
+
array.
|
|
107
|
+
pred_data : Optional[Any]
|
|
108
|
+
Prediction data, can be a path to a folder, a list of paths, or a numpy array.
|
|
109
|
+
pred_data_target : Optional[Any]
|
|
110
|
+
Prediction data target, can be a path to a folder, a list of paths, or a numpy
|
|
111
|
+
array.
|
|
112
|
+
|
|
113
|
+
Raises
|
|
114
|
+
------
|
|
115
|
+
ValueError
|
|
116
|
+
If none of train_data, val_data or pred_data is provided.
|
|
117
|
+
ValueError
|
|
118
|
+
If input and target data types are not consistent.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
# standard use
|
|
27
122
|
@overload
|
|
28
123
|
def __init__(
|
|
29
124
|
self,
|
|
30
|
-
data_config:
|
|
125
|
+
data_config: NGDataConfig,
|
|
31
126
|
*,
|
|
32
127
|
train_data: Optional[InputType] = None,
|
|
33
128
|
train_data_target: Optional[InputType] = None,
|
|
@@ -41,10 +136,11 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
41
136
|
use_in_memory: bool = True,
|
|
42
137
|
) -> None: ...
|
|
43
138
|
|
|
139
|
+
# custom read function
|
|
44
140
|
@overload
|
|
45
141
|
def __init__(
|
|
46
142
|
self,
|
|
47
|
-
data_config:
|
|
143
|
+
data_config: NGDataConfig,
|
|
48
144
|
*,
|
|
49
145
|
train_data: Optional[InputType] = None,
|
|
50
146
|
train_data_target: Optional[InputType] = None,
|
|
@@ -63,7 +159,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
63
159
|
@overload
|
|
64
160
|
def __init__(
|
|
65
161
|
self,
|
|
66
|
-
data_config:
|
|
162
|
+
data_config: NGDataConfig,
|
|
67
163
|
*,
|
|
68
164
|
train_data: Optional[Any] = None,
|
|
69
165
|
train_data_target: Optional[Any] = None,
|
|
@@ -81,7 +177,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
81
177
|
|
|
82
178
|
def __init__(
|
|
83
179
|
self,
|
|
84
|
-
data_config:
|
|
180
|
+
data_config: NGDataConfig,
|
|
85
181
|
*,
|
|
86
182
|
train_data: Optional[Any] = None,
|
|
87
183
|
train_data_target: Optional[Any] = None,
|
|
@@ -106,7 +202,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
106
202
|
|
|
107
203
|
Parameters
|
|
108
204
|
----------
|
|
109
|
-
data_config :
|
|
205
|
+
data_config : NGDataConfig
|
|
110
206
|
Pydantic model for CAREamics data configuration.
|
|
111
207
|
train_data : Optional[InputType]
|
|
112
208
|
Training data, can be a path to a folder, a list of paths, or a numpy array.
|
|
@@ -153,7 +249,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
153
249
|
"At least one of train_data, val_data or pred_data must be provided."
|
|
154
250
|
)
|
|
155
251
|
|
|
156
|
-
self.config:
|
|
252
|
+
self.config: NGDataConfig = data_config
|
|
157
253
|
self.data_type: str = data_config.data_type
|
|
158
254
|
self.batch_size: int = data_config.batch_size
|
|
159
255
|
self.use_in_memory: bool = use_in_memory
|
|
@@ -186,7 +282,16 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
186
282
|
input_data: InputType,
|
|
187
283
|
target_data: Optional[InputType],
|
|
188
284
|
) -> None:
|
|
189
|
-
"""Validate if the input and target data types are consistent.
|
|
285
|
+
"""Validate if the input and target data types are consistent.
|
|
286
|
+
|
|
287
|
+
Parameters
|
|
288
|
+
----------
|
|
289
|
+
input_data : InputType
|
|
290
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
291
|
+
target_data : Optional[InputType]
|
|
292
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
293
|
+
array.
|
|
294
|
+
"""
|
|
190
295
|
if input_data is not None and target_data is not None:
|
|
191
296
|
if not isinstance(input_data, type(target_data)):
|
|
192
297
|
raise ValueError(
|
|
@@ -210,7 +315,22 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
210
315
|
input_data,
|
|
211
316
|
target_data=None,
|
|
212
317
|
) -> tuple[list[Path], Optional[list[Path]]]:
|
|
213
|
-
"""List files from input and target directories.
|
|
318
|
+
"""List files from input and target directories.
|
|
319
|
+
|
|
320
|
+
Parameters
|
|
321
|
+
----------
|
|
322
|
+
input_data : InputType
|
|
323
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
324
|
+
target_data : Optional[InputType]
|
|
325
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
326
|
+
array.
|
|
327
|
+
|
|
328
|
+
Returns
|
|
329
|
+
-------
|
|
330
|
+
(list[Path], Optional[list[Path]])
|
|
331
|
+
A tuple containing lists of file paths for input and target data.
|
|
332
|
+
If target_data is None, the second element will be None.
|
|
333
|
+
"""
|
|
214
334
|
input_data = Path(input_data)
|
|
215
335
|
input_files = list_files(input_data, self.data_type, self.extension_filter)
|
|
216
336
|
if target_data is None:
|
|
@@ -228,7 +348,22 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
228
348
|
input_data,
|
|
229
349
|
target_data=None,
|
|
230
350
|
) -> tuple[list[Path], Optional[list[Path]]]:
|
|
231
|
-
"""Create a list of file paths from the input and target data.
|
|
351
|
+
"""Create a list of file paths from the input and target data.
|
|
352
|
+
|
|
353
|
+
Parameters
|
|
354
|
+
----------
|
|
355
|
+
input_data : InputType
|
|
356
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
357
|
+
target_data : Optional[InputType]
|
|
358
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
359
|
+
array.
|
|
360
|
+
|
|
361
|
+
Returns
|
|
362
|
+
-------
|
|
363
|
+
(list[Path], Optional[list[Path]])
|
|
364
|
+
A tuple containing lists of file paths for input and target data.
|
|
365
|
+
If target_data is None, the second element will be None.
|
|
366
|
+
"""
|
|
232
367
|
input_files = [
|
|
233
368
|
Path(item) if isinstance(item, str) else item for item in input_data
|
|
234
369
|
]
|
|
@@ -246,7 +381,21 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
246
381
|
input_data: InputType,
|
|
247
382
|
target_data: Optional[InputType],
|
|
248
383
|
) -> tuple[Any, Any]:
|
|
249
|
-
"""Validate if the input data is a numpy array.
|
|
384
|
+
"""Validate if the input data is a numpy array.
|
|
385
|
+
|
|
386
|
+
Parameters
|
|
387
|
+
----------
|
|
388
|
+
input_data : InputType
|
|
389
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
390
|
+
target_data : Optional[InputType]
|
|
391
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
392
|
+
array.
|
|
393
|
+
|
|
394
|
+
Returns
|
|
395
|
+
-------
|
|
396
|
+
(Any, Any)
|
|
397
|
+
A tuple containing the input and target.
|
|
398
|
+
"""
|
|
250
399
|
if isinstance(input_data, np.ndarray):
|
|
251
400
|
input_array = [input_data]
|
|
252
401
|
target_array = [target_data] if target_data is not None else None
|
|
@@ -261,9 +410,25 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
261
410
|
def _validate_path_input(
|
|
262
411
|
self, input_data: InputType, target_data: Optional[InputType]
|
|
263
412
|
) -> tuple[list[Path], Optional[list[Path]]]:
|
|
264
|
-
if
|
|
413
|
+
"""Validate if the input data is a path or a list of paths.
|
|
414
|
+
|
|
415
|
+
Parameters
|
|
416
|
+
----------
|
|
417
|
+
input_data : InputType
|
|
418
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
419
|
+
target_data : Optional[InputType]
|
|
420
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
421
|
+
array.
|
|
422
|
+
|
|
423
|
+
Returns
|
|
424
|
+
-------
|
|
425
|
+
(list[Path], Optional[list[Path]])
|
|
426
|
+
A tuple containing lists of file paths for input and target data.
|
|
427
|
+
If target_data is None, the second element will be None.
|
|
428
|
+
"""
|
|
429
|
+
if isinstance(input_data, str | Path):
|
|
265
430
|
if target_data is not None:
|
|
266
|
-
assert isinstance(target_data,
|
|
431
|
+
assert isinstance(target_data, str | Path)
|
|
267
432
|
input_list, target_list = self._list_files_in_directory(
|
|
268
433
|
input_data, target_data
|
|
269
434
|
)
|
|
@@ -281,17 +446,33 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
281
446
|
)
|
|
282
447
|
|
|
283
448
|
def _validate_custom_input(self, input_data, target_data) -> tuple[Any, Any]:
|
|
449
|
+
"""Convert custom input data to a list of file paths.
|
|
450
|
+
|
|
451
|
+
Parameters
|
|
452
|
+
----------
|
|
453
|
+
input_data : InputType
|
|
454
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
455
|
+
target_data : Optional[InputType]
|
|
456
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
457
|
+
array.
|
|
458
|
+
|
|
459
|
+
Returns
|
|
460
|
+
-------
|
|
461
|
+
(Any, Any)
|
|
462
|
+
A tuple containing lists of file paths for input and target data.
|
|
463
|
+
If target_data is None, the second element will be None.
|
|
464
|
+
"""
|
|
284
465
|
if self.image_stack_loader is not None:
|
|
285
466
|
return input_data, target_data
|
|
286
|
-
elif isinstance(input_data,
|
|
467
|
+
elif isinstance(input_data, str | Path):
|
|
287
468
|
if target_data is not None:
|
|
288
|
-
assert isinstance(target_data,
|
|
469
|
+
assert isinstance(target_data, str | Path)
|
|
289
470
|
input_list, target_list = self._list_files_in_directory(
|
|
290
471
|
input_data, target_data
|
|
291
472
|
)
|
|
292
473
|
return input_list, target_list
|
|
293
474
|
elif isinstance(input_data, list):
|
|
294
|
-
if isinstance(input_data[0],
|
|
475
|
+
if isinstance(input_data[0], str | Path):
|
|
295
476
|
if target_data is not None:
|
|
296
477
|
assert isinstance(target_data, list)
|
|
297
478
|
input_list, target_list = self._convert_paths_to_pathlib(
|
|
@@ -313,13 +494,22 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
313
494
|
"""
|
|
314
495
|
Initialize a pair of input and target data.
|
|
315
496
|
|
|
497
|
+
Parameters
|
|
498
|
+
----------
|
|
499
|
+
input_data : InputType
|
|
500
|
+
Input data, can be None, a path to a folder, a list of paths, or a numpy
|
|
501
|
+
array.
|
|
502
|
+
target_data : Optional[InputType]
|
|
503
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
504
|
+
array.
|
|
505
|
+
|
|
316
506
|
Returns
|
|
317
507
|
-------
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
A tuple containing the initialized input and target data.
|
|
321
|
-
|
|
322
|
-
|
|
508
|
+
(list of numpy.ndarray or list of pathlib.Path, None or list of numpy.ndarray or
|
|
509
|
+
list of pathlib.Path)
|
|
510
|
+
A tuple containing the initialized input and target data. For file paths,
|
|
511
|
+
returns lists of Path objects. For numpy arrays, returns the arrays
|
|
512
|
+
directly.
|
|
323
513
|
"""
|
|
324
514
|
if input_data is None:
|
|
325
515
|
return None, None
|
|
@@ -341,11 +531,11 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
341
531
|
raise ValueError(
|
|
342
532
|
f"Unsupported input type for {self.data_type}: {type(input_data)}"
|
|
343
533
|
)
|
|
344
|
-
elif self.data_type
|
|
345
|
-
if isinstance(input_data,
|
|
534
|
+
elif self.data_type in (SupportedData.TIFF, SupportedData.CZI):
|
|
535
|
+
if isinstance(input_data, str | Path):
|
|
346
536
|
return self._validate_path_input(input_data, target_data)
|
|
347
537
|
elif isinstance(input_data, list):
|
|
348
|
-
if isinstance(input_data[0],
|
|
538
|
+
if isinstance(input_data[0], str | Path):
|
|
349
539
|
return self._validate_path_input(input_data, target_data)
|
|
350
540
|
else:
|
|
351
541
|
raise ValueError(
|
|
@@ -484,5 +674,5 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
484
674
|
self.predict_dataset,
|
|
485
675
|
batch_size=self.batch_size,
|
|
486
676
|
collate_fn=default_collate,
|
|
487
|
-
|
|
677
|
+
**self.config.test_dataloader_params,
|
|
488
678
|
)
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
"""CARE Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, Union
|
|
2
5
|
|
|
3
6
|
from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
|
|
4
7
|
from careamics.config.algorithms.n2n_algorithm_model import N2NAlgorithm
|
|
@@ -13,12 +16,27 @@ logger = get_logger(__name__)
|
|
|
13
16
|
|
|
14
17
|
|
|
15
18
|
class CAREModule(UnetModule):
|
|
16
|
-
"""CAREamics PyTorch Lightning module for CARE algorithm.
|
|
19
|
+
"""CAREamics PyTorch Lightning module for CARE algorithm.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
algorithm_config : CAREAlgorithm or dict
|
|
24
|
+
Configuration for the CARE algorithm, either as a CAREAlgorithm instance or a
|
|
25
|
+
dictionary.
|
|
26
|
+
"""
|
|
17
27
|
|
|
18
28
|
def __init__(self, algorithm_config: Union[CAREAlgorithm, dict]) -> None:
|
|
29
|
+
"""Instantiate CARE DataModule.
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
algorithm_config : CAREAlgorithm or dict
|
|
34
|
+
Configuration for the CARE algorithm, either as a CAREAlgorithm instance or
|
|
35
|
+
a dictionary.
|
|
36
|
+
"""
|
|
19
37
|
super().__init__(algorithm_config)
|
|
20
38
|
assert isinstance(
|
|
21
|
-
algorithm_config,
|
|
39
|
+
algorithm_config, CAREAlgorithm | N2NAlgorithm
|
|
22
40
|
), "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
|
|
23
41
|
loss = algorithm_config.loss
|
|
24
42
|
if loss == SupportedLoss.MAE:
|
|
@@ -33,7 +51,20 @@ class CAREModule(UnetModule):
|
|
|
33
51
|
batch: tuple[ImageRegionData, ImageRegionData],
|
|
34
52
|
batch_idx: Any,
|
|
35
53
|
) -> Any:
|
|
36
|
-
"""Training step for CARE module.
|
|
54
|
+
"""Training step for CARE module.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
batch : (ImageRegionData, ImageRegionData)
|
|
59
|
+
A tuple containing the input data and the target data.
|
|
60
|
+
batch_idx : Any
|
|
61
|
+
The index of the current batch in the training loop.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
Any
|
|
66
|
+
The loss value computed for the current batch.
|
|
67
|
+
"""
|
|
37
68
|
# TODO: add validation to determine if target is initialized
|
|
38
69
|
x, target = batch[0], batch[1]
|
|
39
70
|
|
|
@@ -49,7 +80,15 @@ class CAREModule(UnetModule):
|
|
|
49
80
|
batch: tuple[ImageRegionData, ImageRegionData],
|
|
50
81
|
batch_idx: Any,
|
|
51
82
|
) -> None:
|
|
52
|
-
"""Validation step for CARE module.
|
|
83
|
+
"""Validation step for CARE module.
|
|
84
|
+
|
|
85
|
+
Parameters
|
|
86
|
+
----------
|
|
87
|
+
batch : (ImageRegionData, ImageRegionData)
|
|
88
|
+
A tuple containing the input data and the target data.
|
|
89
|
+
batch_idx : Any
|
|
90
|
+
The index of the current batch in the validation loop.
|
|
91
|
+
"""
|
|
53
92
|
x, target = batch[0], batch[1]
|
|
54
93
|
|
|
55
94
|
prediction = self.model(x.data)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Noise2Void Lightning DataModule."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Union
|
|
2
4
|
|
|
3
5
|
from careamics.config import (
|
|
@@ -14,9 +16,24 @@ logger = get_logger(__name__)
|
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class N2VModule(UnetModule):
|
|
17
|
-
"""CAREamics PyTorch Lightning module for N2V algorithm.
|
|
19
|
+
"""CAREamics PyTorch Lightning module for N2V algorithm.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
algorithm_config : N2VAlgorithm or dict
|
|
24
|
+
Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
|
|
25
|
+
dictionary.
|
|
26
|
+
"""
|
|
18
27
|
|
|
19
28
|
def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
|
|
29
|
+
"""Instantiate N2V DataModule.
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
algorithm_config : N2VAlgorithm or dict
|
|
34
|
+
Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
|
|
35
|
+
dictionary.
|
|
36
|
+
"""
|
|
20
37
|
super().__init__(algorithm_config)
|
|
21
38
|
|
|
22
39
|
assert isinstance(
|
|
@@ -29,6 +46,7 @@ class N2VModule(UnetModule):
|
|
|
29
46
|
self.loss_func = n2v_loss
|
|
30
47
|
|
|
31
48
|
def _load_best_checkpoint(self) -> None:
|
|
49
|
+
"""Load the best checkpoint for N2V model."""
|
|
32
50
|
logger.warning(
|
|
33
51
|
"Loading best checkpoint for N2V model. Note that for N2V, "
|
|
34
52
|
"the checkpoint with the best validation metrics may not necessarily "
|
|
@@ -41,7 +59,20 @@ class N2VModule(UnetModule):
|
|
|
41
59
|
batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
|
|
42
60
|
batch_idx: Any,
|
|
43
61
|
) -> Any:
|
|
44
|
-
"""Training step for N2V model.
|
|
62
|
+
"""Training step for N2V model.
|
|
63
|
+
|
|
64
|
+
Parameters
|
|
65
|
+
----------
|
|
66
|
+
batch : ImageRegionData or (ImageRegionData, ImageRegionData)
|
|
67
|
+
A tuple containing the input data and optionally the target data.
|
|
68
|
+
batch_idx : Any
|
|
69
|
+
The index of the current batch in the training loop.
|
|
70
|
+
|
|
71
|
+
Returns
|
|
72
|
+
-------
|
|
73
|
+
Any
|
|
74
|
+
The loss value for the current training step.
|
|
75
|
+
"""
|
|
45
76
|
x = batch[0]
|
|
46
77
|
x_masked, x_original, mask = self.n2v_manipulate(x.data)
|
|
47
78
|
prediction = self.model(x_masked)
|
|
@@ -56,7 +87,15 @@ class N2VModule(UnetModule):
|
|
|
56
87
|
batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
|
|
57
88
|
batch_idx: Any,
|
|
58
89
|
) -> None:
|
|
59
|
-
"""Validation step for N2V model.
|
|
90
|
+
"""Validation step for N2V model.
|
|
91
|
+
|
|
92
|
+
Parameters
|
|
93
|
+
----------
|
|
94
|
+
batch : ImageRegionData or (ImageRegionData, ImageRegionData)
|
|
95
|
+
A tuple containing the input data and optionally the target data.
|
|
96
|
+
batch_idx : Any
|
|
97
|
+
The index of the current batch in the validation loop.
|
|
98
|
+
"""
|
|
60
99
|
x = batch[0]
|
|
61
100
|
|
|
62
101
|
x_masked, x_original, mask = self.n2v_manipulate(x.data)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Generic UNet Lightning DataModule."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Union
|
|
2
4
|
|
|
3
5
|
import pytorch_lightning as L
|
|
@@ -18,11 +20,27 @@ logger = get_logger(__name__)
|
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
class UnetModule(L.LightningModule):
|
|
21
|
-
"""CAREamics PyTorch Lightning module for UNet based algorithms.
|
|
23
|
+
"""CAREamics PyTorch Lightning module for UNet based algorithms.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
|
|
28
|
+
Configuration for the algorithm, either as an instance of a specific algorithm
|
|
29
|
+
class or a dictionary that can be converted to an algorithm instance.
|
|
30
|
+
"""
|
|
22
31
|
|
|
23
32
|
def __init__(
|
|
24
33
|
self, algorithm_config: Union[CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, dict]
|
|
25
34
|
) -> None:
|
|
35
|
+
"""Instantiate UNet DataModule.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, or dict
|
|
40
|
+
Configuration for the algorithm, either as an instance of a specific
|
|
41
|
+
algorithm class or a dictionary that can be converted to an algorithm
|
|
42
|
+
instance.
|
|
43
|
+
"""
|
|
26
44
|
super().__init__()
|
|
27
45
|
|
|
28
46
|
if isinstance(algorithm_config, dict):
|
|
@@ -37,10 +55,30 @@ class UnetModule(L.LightningModule):
|
|
|
37
55
|
self.metrics = MetricCollection(PeakSignalNoiseRatio())
|
|
38
56
|
|
|
39
57
|
def forward(self, x: Any) -> Any:
|
|
40
|
-
"""Default forward method.
|
|
58
|
+
"""Default forward method.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
x : Any
|
|
63
|
+
Input data.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
Any
|
|
68
|
+
Output from the model.
|
|
69
|
+
"""
|
|
41
70
|
return self.model(x)
|
|
42
71
|
|
|
43
72
|
def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
|
|
73
|
+
"""Log training statistics.
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
loss : Any
|
|
78
|
+
The loss value for the current training step.
|
|
79
|
+
batch_size : Any
|
|
80
|
+
The size of the batch used in the current training step.
|
|
81
|
+
"""
|
|
44
82
|
self.log(
|
|
45
83
|
"train_loss",
|
|
46
84
|
loss,
|
|
@@ -66,6 +104,15 @@ class UnetModule(L.LightningModule):
|
|
|
66
104
|
)
|
|
67
105
|
|
|
68
106
|
def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
|
|
107
|
+
"""Log validation statistics.
|
|
108
|
+
|
|
109
|
+
Parameters
|
|
110
|
+
----------
|
|
111
|
+
loss : Any
|
|
112
|
+
The loss value for the current validation step.
|
|
113
|
+
batch_size : Any
|
|
114
|
+
The size of the batch used in the current validation step.
|
|
115
|
+
"""
|
|
69
116
|
self.log(
|
|
70
117
|
"val_loss",
|
|
71
118
|
loss,
|
|
@@ -78,6 +125,7 @@ class UnetModule(L.LightningModule):
|
|
|
78
125
|
self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
|
|
79
126
|
|
|
80
127
|
def _load_best_checkpoint(self) -> None:
|
|
128
|
+
"""Load the best checkpoint from the trainer's checkpoint callback."""
|
|
81
129
|
if (
|
|
82
130
|
not hasattr(self.trainer, "checkpoint_callback")
|
|
83
131
|
or self.trainer.checkpoint_callback is None
|
|
@@ -99,7 +147,22 @@ class UnetModule(L.LightningModule):
|
|
|
99
147
|
batch_idx: Any,
|
|
100
148
|
load_best_checkpoint=False,
|
|
101
149
|
) -> Any:
|
|
102
|
-
"""Default predict step.
|
|
150
|
+
"""Default predict step.
|
|
151
|
+
|
|
152
|
+
Parameters
|
|
153
|
+
----------
|
|
154
|
+
batch : ImageRegionData or (ImageRegionData, ImageRegionData)
|
|
155
|
+
A tuple containing the input data and optionally the target data.
|
|
156
|
+
batch_idx : Any
|
|
157
|
+
The index of the current batch in the prediction loop.
|
|
158
|
+
load_best_checkpoint : bool, default=False
|
|
159
|
+
Whether to load the best checkpoint before making predictions.
|
|
160
|
+
|
|
161
|
+
Returns
|
|
162
|
+
-------
|
|
163
|
+
Any
|
|
164
|
+
The output batch containing the predictions.
|
|
165
|
+
"""
|
|
103
166
|
if self._best_checkpoint_loaded is False and load_best_checkpoint:
|
|
104
167
|
self._load_best_checkpoint()
|
|
105
168
|
self._best_checkpoint_loaded = True
|
|
@@ -127,7 +190,13 @@ class UnetModule(L.LightningModule):
|
|
|
127
190
|
return output_batch
|
|
128
191
|
|
|
129
192
|
def configure_optimizers(self) -> Any:
|
|
130
|
-
"""Configure optimizers.
|
|
193
|
+
"""Configure optimizers.
|
|
194
|
+
|
|
195
|
+
Returns
|
|
196
|
+
-------
|
|
197
|
+
Any
|
|
198
|
+
A dictionary containing the optimizer and learning rate scheduler.
|
|
199
|
+
"""
|
|
131
200
|
optimizer_func = get_optimizer(self.config.optimizer.name)
|
|
132
201
|
optimizer = optimizer_func(
|
|
133
202
|
self.model.parameters(), **self.config.optimizer.parameters
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytorch_lightning as L
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Training and validation Lightning data modules."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytorch_lightning as L
|