careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- careamics/__init__.py +8 -6
- careamics/careamist.py +30 -29
- careamics/config/__init__.py +12 -9
- careamics/config/algorithm_model.py +5 -5
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +87 -0
- careamics/config/configuration_factory.py +285 -78
- careamics/config/configuration_model.py +22 -23
- careamics/config/data_model.py +62 -160
- careamics/config/inference_model.py +20 -21
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +2 -1
- careamics/config/transformations/n2v_manipulate_model.py +2 -1
- careamics/config/transformations/nd_flip_model.py +7 -12
- careamics/config/transformations/normalize_model.py +2 -1
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +7 -9
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +17 -48
- careamics/dataset/iterable_dataset.py +16 -71
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +123 -49
- careamics/lightning_module.py +7 -7
- careamics/lightning_prediction_datamodule.py +59 -48
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +4 -3
- careamics/model_io/bmz_io.py +8 -7
- careamics/model_io/model_io_utils.py +4 -4
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc3.dist-info/RECORD +0 -109
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
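The most visible change in this release is a renaming of the public Lightning classes and of the Pydantic configuration models (now `DataConfig`, `AlgorithmConfig`, `InferenceConfig`), together with the removal of the `from albumentations import Compose` imports in the data modules. As a quick orientation, the sketch below collects the class renames; it is assembled from the diff hunks that follow, not from an official careamics migration table.

```python
# Orientation sketch assembled from the diff hunks below (not an official
# careamics migration table): public classes renamed between rc3 and rc5.
RENAMED_CLASSES = {
    # careamics/lightning_datamodule.py
    "CAREamicsWood": "CAREamicsTrainData",
    "CAREamicsTrainDataModule": "TrainingDataWrapper",
    # careamics/lightning_module.py
    "CAREamicsKiln": "CAREamicsModule",
    "CAREamicsModule": "CAREamicsModuleWrapper",
    # careamics/lightning_prediction_datamodule.py
    "CAREamicsClay": "CAREamicsPredictData",
    "CAREamicsPredictDataModule": "PredictDataWrapper",
}
```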
careamics/lightning_datamodule.py
CHANGED

@@ -1,12 +1,13 @@
+"""Training and validation Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Union

 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader

-from careamics.config import DataModel
+from careamics.config import DataConfig
 from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
 from careamics.dataset.dataset_utils import (
@@ -28,9 +29,9 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
 logger = get_logger(__name__)


-class CAREamicsWood(L.LightningDataModule):
+class CAREamicsTrainData(L.LightningDataModule):
     """
-
+    CAREamics Ligthning training and validation data module.

     The data module can be used with Path, str or numpy arrays. In the case of
     numpy arrays, it loads and computes all the patches in memory. For Path and str
@@ -53,11 +54,70 @@ class CAREamicsWood(L.LightningDataModule):

     You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
     "*.czi") to filter the files extension using `extension_filter`.
+
+    Parameters
+    ----------
+    data_config : DataModel
+        Pydantic model for CAREamics data configuration.
+    train_data : Union[Path, str, np.ndarray]
+        Training data, can be a path to a folder, a file or a numpy array.
+    val_data : Optional[Union[Path, str, np.ndarray]], optional
+        Validation data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        Training target data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        Validation target data, can be a path to a folder, a file or a numpy array,
+        by default None.
+    read_source_func : Optional[Callable], optional
+        Function to read the source data, by default None. Only used for `custom`
+        data type (see DataModel).
+    extension_filter : str, optional
+        Filter for file extensions, by default "". Only used for `custom` data types
+        (see DataModel).
+    val_percentage : float, optional
+        Percentage of the training data to use for validation, by default 0.1. Only
+        used if `val_data` is None.
+    val_minimum_split : int, optional
+        Minimum number of patches or files to split from the training data for
+        validation, by default 5. Only used if `val_data` is None.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+
+    Attributes
+    ----------
+    data_config : DataModel
+        CAREamics data configuration.
+    data_type : SupportedData
+        Expected data type, one of "tiff", "array" or "custom".
+    batch_size : int
+        Batch size.
+    use_in_memory : bool
+        Whether to use in memory dataset if possible.
+    train_data : Union[Path, str, np.ndarray]
+        Training data.
+    val_data : Optional[Union[Path, str, np.ndarray]]
+        Validation data.
+    train_data_target : Optional[Union[Path, str, np.ndarray]]
+        Training target data.
+    val_data_target : Optional[Union[Path, str, np.ndarray]]
+        Validation target data.
+    val_percentage : float
+        Percentage of the training data to use for validation, if no validation data is
+        provided.
+    val_minimum_split : int
+        Minimum number of patches or files to split from the training data for
+        validation, if no validation data is provided.
+    read_source_func : Optional[Callable]
+        Function to read the source data, used if `data_type` is `custom`.
+    extension_filter : str
+        Filter for file extensions, used if `data_type` is `custom`.
     """

     def __init__(
         self,
-        data_config: DataModel,
+        data_config: DataConfig,
         train_data: Union[Path, str, np.ndarray],
         val_data: Optional[Union[Path, str, np.ndarray]] = None,
         train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
@@ -98,6 +158,8 @@ class CAREamicsWood(L.LightningDataModule):
         val_minimum_split : int, optional
             Minimum number of patches or files to split from the training data for
             validation, by default 5. Only used if `val_data` is None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.

         Raises
         ------
@@ -128,25 +190,30 @@ class CAREamicsWood(L.LightningDataModule):
         if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
             raise ValueError(
                 f"Data type {SupportedData.CUSTOM} is not allowed without "
-                f"specifying a `read_source_func`."
+                f"specifying a `read_source_func` and an `extension_filer`."
             )

-        #
-
-            train_data, np.ndarray
+        # check correct input type
+        if (
+            isinstance(train_data, np.ndarray)
+            and data_config.data_type != SupportedData.ARRAY
         ):
             raise ValueError(
-                f"
-                f"{
+                f"Received a numpy array as input, but the data type was set to "
+                f"{data_config.data_type}. Set the data type in the configuration "
+                f"to {SupportedData.ARRAY} to train on numpy arrays."
             )

         # and that Path or str are passed, if tiff file type specified
-        elif
-
+        elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+            data_config.data_type != SupportedData.TIFF
+            and data_config.data_type != SupportedData.CUSTOM
         ):
             raise ValueError(
-                f"
-                f"
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f"in the configuration to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to train on files."
             )

         # configuration
@@ -231,7 +298,15 @@ class CAREamicsWood(L.LightningDataModule):
             validate_source_target_files(self.val_files, self.val_target_files)

     def setup(self, *args: Any, **kwargs: Any) -> None:
-        """Hook called at the beginning of fit, validate, or predict."""
+        """Hook called at the beginning of fit, validate, or predict.
+
+        Parameters
+        ----------
+        *args : Any
+            Unused.
+        **kwargs : Any
+            Unused.
+        """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
             # train dataset
@@ -266,9 +341,9 @@ class CAREamicsWood(L.LightningDataModule):
             self.train_dataset = InMemoryDataset(
                 data_config=self.data_config,
                 inputs=self.train_files,
-                data_target=self.train_target_files
-                if self.train_data_target
-                else None,
+                data_target=(
+                    self.train_target_files if self.train_data_target else None
+                ),
                 read_source_func=self.read_source_func,
             )

@@ -277,9 +352,9 @@ class CAREamicsWood(L.LightningDataModule):
                 self.val_dataset = InMemoryDataset(
                     data_config=self.data_config,
                     inputs=self.val_files,
-                    data_target=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    data_target=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                     read_source_func=self.read_source_func,
                 )
             else:
@@ -295,9 +370,9 @@ class CAREamicsWood(L.LightningDataModule):
             self.train_dataset = PathIterableDataset(
                 data_config=self.data_config,
                 src_files=self.train_files,
-                target_files=self.train_target_files
-                if self.train_data_target
-                else None,
+                target_files=(
+                    self.train_target_files if self.train_data_target else None
+                ),
                 read_source_func=self.read_source_func,
             )

@@ -307,9 +382,9 @@ class CAREamicsWood(L.LightningDataModule):
                 self.val_dataset = PathIterableDataset(
                     data_config=self.data_config,
                     src_files=self.val_files,
-                    target_files=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    target_files=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                     read_source_func=self.read_source_func,
                 )
             elif len(self.train_files) <= self.val_minimum_split:
@@ -353,9 +428,12 @@ class CAREamicsWood(L.LightningDataModule):
             )


-class CAREamicsTrainDataModule(CAREamicsWood):
+class TrainingDataWrapper(CAREamicsTrainData):
     """
-
+    Wrapper around the CAREamics Lightning training data module.
+
+    This class is used to explicitely pass the parameters usually contained in a
+    `data_model` configuration.

     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
@@ -374,8 +452,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
     In particular, N2V requires a specific transformation (N2V manipulates), which is
     not compatible with supervised training. The default transformations applied to the
     training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms to the
-    `transforms` parameter. See examples for more details.
+    transformations, pass a list of transforms. See examples for more details.

     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -410,7 +487,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
         Batch size.
     val_data : Optional[Union[str, Path]], optional
         Validation data, by default None.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List[TRANSFORMS_UNION], optional
         List of transforms to apply to training patches. If None, default transforms
         are applied.
     train_target_data : Optional[Union[str, Path]], optional
@@ -442,11 +519,11 @@ class CAREamicsTrainDataModule(CAREamicsWood):

     Examples
     --------
-    Create a CAREamicsTrainDataModule with default transforms with a numpy array:
+    Create a TrainingDataWrapper with default transforms with a numpy array:
     >>> import numpy as np
-    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics import TrainingDataWrapper
     >>> my_array = np.arange(256).reshape(16, 16)
-    >>> data_module = CAREamicsTrainDataModule(
+    >>> data_module = TrainingDataWrapper(
     ...     train_data=my_array,
     ...     data_type="array",
     ...     patch_size=(8, 8),
@@ -457,12 +534,12 @@ class CAREamicsTrainDataModule(CAREamicsWood):
     For custom data types (those not supported by CAREamics), then one can pass a read
     function and a filter for the files extension:
     >>> import numpy as np
-    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics import TrainingDataWrapper
     >>>
     >>> def read_npy(path):
     ...     return np.load(path)
     >>>
-    >>> data_module = CAREamicsTrainDataModule(
+    >>> data_module = TrainingDataWrapper(
     ...     train_data="path/to/data",
     ...     data_type="custom",
     ...     patch_size=(8, 8),
@@ -475,7 +552,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
     If you want to use a different set of transformations, you can pass a list of
     transforms:
     >>> import numpy as np
-    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics import TrainingDataWrapper
     >>> from careamics.config.support import SupportedTransform
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
@@ -488,7 +565,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
     ...         "name": SupportedTransform.N2V_MANIPULATE.value,
     ...     }
     ... ]
-    >>> data_module = CAREamicsTrainDataModule(
+    >>> data_module = TrainingDataWrapper(
     ...     train_data=my_array,
     ...     data_type="array",
     ...     patch_size=(8, 8),
@@ -506,7 +583,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
         axes: str,
         batch_size: int,
         val_data: Optional[Union[str, Path]] = None,
-        transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
+        transforms: Optional[List[TRANSFORMS_UNION]] = None,
         train_target_data: Optional[Union[str, Path]] = None,
         val_target_data: Optional[Union[str, Path]] = None,
         read_source_func: Optional[Callable] = None,
@@ -539,8 +616,8 @@ class CAREamicsTrainDataModule(CAREamicsWood):
        In particular, N2V requires a specific transformation (N2V manipulates), which
        is not compatible with supervised training. The default transformations applied
        to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms to the
-        `transforms` parameter. See examples for more details.
+        different transformations, pass a list of transforms. See examples for more
+        details.

        By default, CAREamics only supports types defined in
        `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -577,7 +654,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
            Batch size.
        val_data : Optional[Union[str, Path]], optional
            Validation data, by default None.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List[TRANSFORMS_UNION]], optional
            List of transforms to apply to training patches. If None, default transforms
            are applied.
        train_target_data : Optional[Union[str, Path]], optional
@@ -628,13 +705,10 @@ class CAREamicsTrainDataModule(CAREamicsWood):
            data_dict["transforms"] = transforms

        # validate configuration
-        self.data_config = DataModel(**data_dict)
+        self.data_config = DataConfig(**data_dict)

        # N2V specific checks, N2V, structN2V, and transforms
-        if (
-            self.data_config.has_transform_list()
-            and self.data_config.has_n2v_manipulate()
-        ):
+        if self.data_config.has_n2v_manipulate():
            # there is not target, n2v2 and structN2V can be changed
            if train_target_data is None:
                self.data_config.set_N2V2(use_n2v2)
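The input-type checks above mean that the data type declared in the configuration must match the Python type of `train_data`. Below is a minimal sketch based on the `TrainingDataWrapper` doctest and signature shown in this diff (`train_data`, `data_type`, `patch_size`, `axes`, `batch_size`); the array content is illustrative.

```python
import numpy as np

from careamics import TrainingDataWrapper

# A numpy array input requires data_type="array"; with a "tiff" or "custom"
# data type, CAREamicsTrainData raises a ValueError during __init__.
my_array = np.arange(256).reshape(16, 16)

data_module = TrainingDataWrapper(
    train_data=my_array,
    data_type="array",
    patch_size=(8, 8),
    axes="YX",
    batch_size=2,
)
```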
careamics/lightning_module.py
CHANGED

@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
 import pytorch_lightning as L
 from torch import Tensor, nn

-from careamics.config import AlgorithmModel
+from careamics.config import AlgorithmConfig
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -17,7 +17,7 @@ from careamics.transforms import Denormalize, ImageRestorationTTA
 from careamics.utils.torch_utils import get_optimizer, get_scheduler


-class CAREamicsKiln(L.LightningModule):
+class CAREamicsModule(L.LightningModule):
     """
     CAREamics Lightning module.

@@ -38,7 +38,7 @@ class CAREamicsKiln(L.LightningModule):
         Learning rate scheduler name.
     """

-    def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None:
+    def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
         """
         CAREamics Lightning module.

@@ -53,7 +53,7 @@ class CAREamicsKiln(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
         if isinstance(algorithm_config, dict):
-            algorithm_config = AlgorithmModel(**algorithm_config)
+            algorithm_config = AlgorithmConfig(**algorithm_config)

         # create model and loss function
         self.model: nn.Module = model_factory(algorithm_config.model)
@@ -162,7 +162,7 @@ class CAREamicsKiln(L.LightningModule):
                 mean=self._trainer.datamodule.predict_dataset.mean,
                 std=self._trainer.datamodule.predict_dataset.std,
             )
-            denormalized_output = denorm(
+            denormalized_output, _ = denorm(patch=output)

            if len(aux) > 0:
                return denormalized_output, aux
@@ -192,7 +192,7 @@ class CAREamicsKiln(L.LightningModule):
         }


-class CAREamicsModule(CAREamicsKiln):
+class CAREamicsModuleWrapper(CAREamicsModule):
     """Class defining the API for CAREamics Lightning layer.

     This class exposes parameters used to create an AlgorithmModel instance, triggering
@@ -287,6 +287,6 @@ class CAREamicsModule(CAREamicsKiln):
         algorithm_configuration["model"] = model_configuration

         # call the parent init using an AlgorithmModel instance
-        super().__init__(AlgorithmModel(**algorithm_configuration))
+        super().__init__(AlgorithmConfig(**algorithm_configuration))

         # TODO add load_from_checkpoint wrapper
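As the constructor above shows, `CAREamicsModule` (previously `CAREamicsKiln`) accepts either an `AlgorithmConfig` instance or a plain dict, which it converts itself when loading from a checkpoint. A minimal sketch; the field names and values in `algorithm_dict` are illustrative assumptions, not taken from this diff.

```python
from careamics.config import AlgorithmConfig
from careamics.lightning_module import CAREamicsModule

# Illustrative configuration dict; the exact required fields are defined by
# the AlgorithmConfig Pydantic model, which is not shown in this diff.
algorithm_dict = {
    "algorithm": "n2v",
    "loss": "n2v",
    "model": {"architecture": "UNet"},
}

# Both calls are equivalent: a dict is converted via AlgorithmConfig(**dict).
module_from_dict = CAREamicsModule(algorithm_config=algorithm_dict)
module_from_config = CAREamicsModule(
    algorithm_config=AlgorithmConfig(**algorithm_dict)
)
```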
careamics/lightning_prediction_datamodule.py
CHANGED

@@ -1,13 +1,14 @@
+"""Prediction Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate

-from careamics.config import InferenceModel
+from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
 from careamics.config.tile_information import TileInformation
 from careamics.dataset.dataset_utils import (
@@ -38,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:

     Parameters
     ----------
-    batch : List[Tuple[np.ndarray, TileInformation]]
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
         Batch of tiles.

     Returns
@@ -62,9 +63,9 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
     return default_collate(new_batch)


-class CAREamicsClay(L.LightningDataModule):
+class CAREamicsPredictData(L.LightningDataModule):
     """
-
+    CAREamics Lightning prediction data module.

     The data module can be used with Path, str or numpy arrays. The data can be either
     a folder containing images or a single file.
@@ -79,7 +80,7 @@ class CAREamicsClay(L.LightningDataModule):

     Parameters
     ----------
-    prediction_config : InferenceModel
+    pred_config : InferenceModel
         Pydantic model for CAREamics prediction configuration.
     pred_data : Union[Path, str, np.ndarray]
         Prediction data, can be a path to a folder, a file or a numpy array.
@@ -93,7 +94,7 @@ class CAREamicsClay(L.LightningDataModule):

     def __init__(
         self,
-        prediction_config: InferenceModel,
+        pred_config: InferenceConfig,
         pred_data: Union[Path, str, np.ndarray],
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
@@ -115,7 +116,7 @@ class CAREamicsClay(L.LightningDataModule):

         Parameters
         ----------
-        prediction_config : InferenceModel
+        pred_config : InferenceModel
             Pydantic model for CAREamics prediction configuration.
         pred_data : Union[Path, str, np.ndarray]
             Prediction data, can be a path to a folder, a file or a numpy array.
@@ -142,51 +143,53 @@ class CAREamicsClay(L.LightningDataModule):
         super().__init__()

         # check that a read source function is provided for custom types
-        if (
-            prediction_config.data_type == SupportedData.CUSTOM
-            and read_source_func is None
-        ):
+        if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
             raise ValueError(
                 f"Data type {SupportedData.CUSTOM} is not allowed without "
-                f"specifying a `read_source_func`."
+                f"specifying a `read_source_func` and an `extension_filer`."
             )

-        #
-
-            pred_data, np.ndarray
+        # check correct input type
+        if (
+            isinstance(pred_data, np.ndarray)
+            and pred_config.data_type != SupportedData.ARRAY
         ):
             raise ValueError(
-                f"
-                f"{
+                f"Received a numpy array as input, but the data type was set to "
+                f"{pred_config.data_type}. Set the data type "
+                f"to {SupportedData.ARRAY} to predict on numpy arrays."
             )

         # and that Path or str are passed, if tiff file type specified
-        elif
-
+        elif (isinstance(pred_data, Path) or isinstance(pred_config, str)) and (
+            pred_config.data_type != SupportedData.TIFF
+            and pred_config.data_type != SupportedData.CUSTOM
         ):
             raise ValueError(
-                f"
-                f"
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f" to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to predict on files."
             )

         # configuration data
-        self.prediction_config = prediction_config
-        self.data_type = prediction_config.data_type
-        self.batch_size = prediction_config.batch_size
+        self.prediction_config = pred_config
+        self.data_type = pred_config.data_type
+        self.batch_size = pred_config.batch_size
         self.dataloader_params = dataloader_params

         self.pred_data = pred_data
-        self.tile_size = prediction_config.tile_size
-        self.tile_overlap = prediction_config.tile_overlap
+        self.tile_size = pred_config.tile_size
+        self.tile_overlap = pred_config.tile_overlap

         # read source function
-        if prediction_config.data_type == SupportedData.CUSTOM:
+        if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
             assert read_source_func is not None

             self.read_source_func: Callable = read_source_func
-        elif prediction_config.data_type != SupportedData.ARRAY:
-            self.read_source_func = get_read_func(prediction_config.data_type)
+        elif pred_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(pred_config.data_type)

         self.extension_filter = extension_filter

@@ -238,9 +241,12 @@ class CAREamicsClay(L.LightningDataModule):
         )  # TODO check workers are used


-class CAREamicsPredictDataModule(CAREamicsClay):
+class PredictDataWrapper(CAREamicsPredictData):
     """
-
+    Wrapper around the CAREamics inference Lightning data module.
+
+    This class is used to explicitely pass the parameters usually contained in a
+    `inference_model` configuration.

     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements
@@ -251,14 +257,13 @@ class CAREamicsPredictDataModule(CAREamicsClay):

     The default transformations applied to the images are defined in
     `careamics.config.inference_model`. To use different transformations, pass a list
-    of transforms to the `transforms` parameter. See examples
+    of transforms. See examples
     for more details.

     The `mean` and `std` parameters are only used if Normalization is defined either
-    in the default transformations or in the `transforms` parameter
-
-
-    to this method.
+    in the default transformations or in the `transforms` parameter. If you pass a
+    `Normalization` transform in a list as `transforms`, then the mean and std
+    parameters will be overwritten by those passed to this method.

     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -270,6 +275,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
     dataloaders, except for `batch_size`, which is set by the `batch_size`
     parameter.

+    Note that if you are using a UNet model and tiling, the tile size must be
+    divisible in every dimension by 2**d, where d is the depth of the model. This
+    avoids artefacts arising from the broken shift invariance induced by the
+    pooling layers of the UNet. If your image has less dimensions, as it may
+    happen in the Z dimension, consider padding your image.
+
     Parameters
     ----------
     pred_data : Union[str, Path, np.ndarray]
@@ -292,7 +303,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
         Batch size.
     tta_transforms : bool, optional
         Use test time augmentation, by default True.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List, optional
         List of transforms to apply to prediction patches. If None, default
         transforms are applied.
     read_source_func : Optional[Callable], optional
@@ -315,7 +326,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
         axes: str = "YX",
         batch_size: int = 1,
         tta_transforms: bool = True,
-        transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
+        transforms: Optional[List] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,
@@ -329,6 +340,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
             Prediction data.
         data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
             Data type, see `SupportedData` for available options.
+        mean : float
+            Mean value for normalization, only used if Normalization is defined in the
+            transforms.
+        std : float
+            Standard deviation value for normalization, only used if Normalization is
+            defined in the transform.
         tile_size : List[int]
             Tile size, 2D or 3D tile size.
         tile_overlap : List[int]
@@ -339,13 +356,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
             Batch size.
         tta_transforms : bool, optional
             Use test time augmentation, by default True.
-        mean : Optional[float], optional
-            Mean value for normalization, only used if Normalization is defined, by
-            default None.
-        std : Optional[float], optional
-            Standard deviation value for normalization, only used if Normalization is
-            defined, by default None.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List], optional
            List of transforms to apply to prediction patches. If None, default
            transforms are applied.
        read_source_func : Optional[Callable], optional
@@ -374,7 +385,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            prediction_dict["transforms"] = transforms

        # validate configuration
-        self.prediction_config = InferenceModel(**prediction_dict)
+        self.prediction_config = InferenceConfig(**prediction_dict)

        # sanity check on the dataloader parameters
        if "batch_size" in dataloader_params:
@@ -382,7 +393,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            del dataloader_params["batch_size"]

        super().__init__(
-            prediction_config=self.prediction_config,
+            pred_config=self.prediction_config,
            pred_data=pred_data,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
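The new docstring note above constrains tile sizes for tiled prediction with a UNet: every tile dimension must be divisible by 2**d, where d is the depth of the model. A small sketch of that rule follows; `is_valid_tile_size` is an illustrative helper, not a careamics function.

```python
from typing import Sequence


def is_valid_tile_size(tile_size: Sequence[int], depth: int) -> bool:
    """Check that each tile dimension is divisible by 2**depth."""
    factor = 2**depth
    return all(dim % factor == 0 for dim in tile_size)


# For a UNet of depth 2, tile dimensions must be multiples of 4.
assert is_valid_tile_size((64, 64), depth=2)
assert not is_valid_tile_size((62, 62), depth=2)
```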
careamics/losses/__init__.py
CHANGED