careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
- careamics/__init__.py +1 -14
- careamics/careamist.py +83 -62
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -0
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +2 -0
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +1 -79
- careamics/config/configuration_model.py +12 -7
- careamics/config/data_model.py +29 -10
- careamics/config/inference_model.py +12 -2
- careamics/config/optimizer_models.py +6 -0
- careamics/config/support/supported_data.py +29 -4
- careamics/config/tile_information.py +10 -0
- careamics/config/training_model.py +5 -1
- careamics/dataset/dataset_utils/__init__.py +0 -6
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
- careamics/dataset/in_memory_dataset.py +37 -21
- careamics/dataset/iterable_dataset.py +38 -34
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/patching.py +53 -37
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +1 -1
- careamics/prediction_utils/__init__.py +0 -2
- careamics/prediction_utils/prediction_outputs.py +18 -46
- careamics/prediction_utils/stitch_prediction.py +17 -14
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
- careamics/config/configuration_example.py +0 -86
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/prediction_utils/create_pred_datamodule.py +0 -185
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
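The dominant change in this release is a reorganization: the Lightning classes move into a `careamics.lightning` subpackage and the readers/writers into a new `careamics.file_io` package. A sketch of the import migration implied by the rename list above (module paths are taken verbatim from the list; `PredictDataModule` is an assumed class name inferred from `predict_data_module.py` and is not confirmed by this diff):

```python
# 0.1.0rc7 (old locations)
# from careamics.lightning_module import CAREamicsModule
# from careamics.lightning_datamodule import CAREamicsTrainData
# from careamics.callbacks import ...

# 0.1.0rc8 (new locations under careamics.lightning and careamics.file_io)
from careamics.lightning.lightning_module import CAREamicsModule
from careamics.lightning.train_data_module import TrainDataModule
from careamics.lightning.predict_data_module import PredictDataModule  # assumed name
from careamics.file_io.read import get_read_func
```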
careamics/{lightning_datamodule.py → lightning/train_data_module.py}
RENAMED

@@ -1,10 +1,11 @@
 """Training and validation Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable,
+from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader
 
 from careamics.config import DataConfig
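The hunk above swaps annotation types: `numpy.typing.NDArray` is imported to replace the bare `np.ndarray` annotations used throughout the module (see the `__init__` signature further down). A minimal illustration of the alias, using a hypothetical function that is not part of careamics:

```python
import numpy as np
from numpy.typing import NDArray

def normalize(image: NDArray) -> NDArray:
    # NDArray is numpy's public annotation alias for ndarray;
    # runtime behaviour is unchanged, only the type hints differ.
    return (image - image.mean()) / image.std()
```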
@@ -12,7 +13,6 @@ from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
 from careamics.dataset.dataset_utils import (
     get_files_size,
-    get_read_func,
     list_files,
     validate_source_target_files,
 )
@@ -22,6 +22,7 @@ from careamics.dataset.in_memory_dataset import (
 from careamics.dataset.iterable_dataset import (
     PathIterableDataset,
 )
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger, get_ram_size
 
 DatasetType = Union[InMemoryDataset, PathIterableDataset]
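These two hunks relocate `get_read_func` from `careamics.dataset.dataset_utils` to the new `careamics.file_io.read` package (its implementation now lives in `careamics/file_io/read/get_func.py`, per the file list). Downstream imports migrate accordingly; the usage line below is a hedged guess, since the function's signature is not shown in this diff:

```python
# before (0.1.0rc7)
# from careamics.dataset.dataset_utils import get_read_func

# after (0.1.0rc8)
from careamics.file_io.read import get_read_func

read = get_read_func("tiff")  # assumed: maps a supported data type to a reader callable
```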
@@ -29,7 +30,7 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
 logger = get_logger(__name__)
 
 
-class CAREamicsTrainData(L.LightningDataModule):
+class TrainDataModule(L.LightningDataModule):
     """
     CAREamics Ligthning training and validation data module.
 
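The class itself is renamed, which is a breaking change for code that instantiated it directly. A migration sketch (old name and module taken from the hunk headers, new module from the rename list; `config` and `train_array` are placeholders):

```python
# before (0.1.0rc7)
# from careamics.lightning_datamodule import CAREamicsTrainData
# data_module = CAREamicsTrainData(data_config=config, train_data=train_array)

# after (0.1.0rc8)
from careamics.lightning.train_data_module import TrainDataModule

data_module = TrainDataModule(data_config=config, train_data=train_array)
```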
@@ -59,18 +60,18 @@ class CAREamicsTrainData(L.LightningDataModule):
     ----------
     data_config : DataModel
         Pydantic model for CAREamics data configuration.
-    train_data :
+    train_data : pathlib.Path or str or numpy.ndarray
         Training data, can be a path to a folder, a file or a numpy array.
-    val_data :
+    val_data : pathlib.Path or str or numpy.ndarray, optional
         Validation data, can be a path to a folder, a file or a numpy array, by
         default None.
-    train_data_target :
+    train_data_target : pathlib.Path or str or numpy.ndarray, optional
         Training target data, can be a path to a folder, a file or a numpy array, by
         default None.
-    val_data_target :
+    val_data_target : pathlib.Path or str or numpy.ndarray, optional
         Validation target data, can be a path to a folder, a file or a numpy array,
         by default None.
-    read_source_func :
+    read_source_func : Callable, optional
         Function to read the source data, by default None. Only used for `custom`
         data type (see DataModel).
     extension_filter : str, optional
@@ -95,13 +96,13 @@ class CAREamicsTrainData(L.LightningDataModule):
         Batch size.
     use_in_memory : bool
         Whether to use in memory dataset if possible.
-    train_data :
+    train_data : pathlib.Path or numpy.ndarray
         Training data.
-    val_data :
+    val_data : pathlib.Path or numpy.ndarray
         Validation data.
-    train_data_target :
+    train_data_target : pathlib.Path or numpy.ndarray
         Training target data.
-    val_data_target :
+    val_data_target : pathlib.Path or numpy.ndarray
         Validation target data.
     val_percentage : float
         Percentage of the training data to use for validation, if no validation data is
@@ -118,10 +119,10 @@ class CAREamicsTrainData(L.LightningDataModule):
     def __init__(
         self,
         data_config: DataConfig,
-        train_data: Union[Path, str,
-        val_data: Optional[Union[Path, str,
-        train_data_target: Optional[Union[Path, str,
-        val_data_target: Optional[Union[Path, str,
+        train_data: Union[Path, str, NDArray],
+        val_data: Optional[Union[Path, str, NDArray]] = None,
+        train_data_target: Optional[Union[Path, str, NDArray]] = None,
+        val_data_target: Optional[Union[Path, str, NDArray]] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         val_percentage: float = 0.1,
@@ -135,18 +136,18 @@ class CAREamicsTrainData(L.LightningDataModule):
         ----------
         data_config : DataModel
             Pydantic model for CAREamics data configuration.
-        train_data :
+        train_data : pathlib.Path or str or numpy.ndarray
             Training data, can be a path to a folder, a file or a numpy array.
-        val_data :
+        val_data : pathlib.Path or str or numpy.ndarray, optional
             Validation data, can be a path to a folder, a file or a numpy array, by
             default None.
-        train_data_target :
+        train_data_target : pathlib.Path or str or numpy.ndarray, optional
             Training target data, can be a path to a folder, a file or a numpy array, by
             default None.
-        val_data_target :
+        val_data_target : pathlib.Path or str or numpy.ndarray, optional
             Validation target data, can be a path to a folder, a file or a numpy array,
             by default None.
-        read_source_func :
+        read_source_func : Callable, optional
             Function to read the source data, by default None. Only used for `custom`
             data type (see DataModel).
         extension_filter : str, optional
@@ -166,7 +167,7 @@ class CAREamicsTrainData(L.LightningDataModule):
         NotImplementedError
             Raised if target data is provided.
         ValueError
-            If the input types are mixed (e.g. Path and
+            If the input types are mixed (e.g. Path and numpy.ndarray).
         ValueError
             If the data type is `custom` and no `read_source_func` is provided.
         ValueError
@@ -223,21 +224,21 @@ class CAREamicsTrainData(L.LightningDataModule):
         self.use_in_memory: bool = use_in_memory
 
         # data: make data Path or np.ndarray, use type annotations for mypy
-        self.train_data: Union[Path,
+        self.train_data: Union[Path, NDArray] = (
             Path(train_data) if isinstance(train_data, str) else train_data
         )
 
-        self.val_data: Union[Path,
+        self.val_data: Union[Path, NDArray] = (
             Path(val_data) if isinstance(val_data, str) else val_data
         )
 
-        self.train_data_target: Union[Path,
+        self.train_data_target: Union[Path, NDArray] = (
             Path(train_data_target)
             if isinstance(train_data_target, str)
             else train_data_target
         )
 
-        self.val_data_target: Union[Path,
+        self.val_data_target: Union[Path, NDArray] = (
             Path(val_data_target)
             if isinstance(val_data_target, str)
             else val_data_target
@@ -260,7 +261,7 @@ class CAREamicsTrainData(L.LightningDataModule):
         self.extension_filter: str = extension_filter
 
         # Pytorch dataloader parameters
-        self.dataloader_params:
+        self.dataloader_params: dict[str, Any] = (
             data_config.dataloader_params if data_config.dataloader_params else {}
         )
 
@@ -298,7 +299,7 @@ class CAREamicsTrainData(L.LightningDataModule):
 
         # same for target data
         if self.train_data_target is not None:
-            self.train_target_files:
+            self.train_target_files: list[Path] = list_files(
                 self.train_data_target, self.data_type, self.extension_filter
             )
 
@@ -403,7 +404,7 @@ class CAREamicsTrainData(L.LightningDataModule):
             )
 
             # create validation dataset
-            if self.
+            if self.val_data is not None:
                 # create its own dataset
                 self.val_dataset = PathIterableDataset(
                     data_config=self.data_config,
@@ -423,9 +424,19 @@ class CAREamicsTrainData(L.LightningDataModule):
                 # extract validation from the training patches
                 self.val_dataset = self.train_dataset.split_dataset(
                     percentage=self.val_percentage,
-
+                    minimum_number=self.val_minimum_split,
                 )
 
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list
+            Means and standard deviations across channels of the training data.
+        """
+        return self.train_dataset.get_data_statistics()
+
     def train_dataloader(self) -> Any:
         """
         Create a dataloader for training.
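The newly added `get_data_statistics` delegates to the training dataset and exposes the per-channel normalization statistics. A short usage sketch, assuming `data_module` has already been set up so that `train_dataset` exists:

```python
means, stds = data_module.get_data_statistics()
for channel, (mean, std) in enumerate(zip(means, stds)):
    print(f"channel {channel}: mean={mean:.3f}, std={std:.3f}")
```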
@@ -454,12 +465,30 @@ class CAREamicsTrainData(L.LightningDataModule):
         )
 
 
-
-
-
-
-
-
+def create_train_datamodule(
+    train_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    patch_size: list[int],
+    axes: str,
+    batch_size: int,
+    val_data: Optional[Union[str, Path, NDArray]] = None,
+    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    train_target_data: Optional[Union[str, Path, NDArray]] = None,
+    val_target_data: Optional[Union[str, Path, NDArray]] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_patches: int = 5,
+    dataloader_params: Optional[dict] = None,
+    use_in_memory: bool = True,
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+) -> TrainDataModule:
+    """Create a TrainDataModule.
+
+    This function is used to explicitely pass the parameters usually contained in a
+    `data_model` configuration to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
@@ -501,26 +530,26 @@ class TrainingDataWrapper(CAREamicsTrainData):
 
     Parameters
     ----------
-    train_data :
+    train_data : pathlib.Path or str or numpy.ndarray
         Training data.
-    data_type :
+    data_type : {"array", "tiff", "custom"}
         Data type, see `SupportedData` for available options.
-    patch_size :
+    patch_size : list of int
         Patch size, 2D or 3D patch size.
     axes : str
         Axes of the data, choosen amongst SCZYX.
     batch_size : int
         Batch size.
-    val_data :
+    val_data : pathlib.Path or str or numpy.ndarray, optional
         Validation data, by default None.
-    transforms :
+    transforms : list of Transforms, optional
         List of transforms to apply to training patches. If None, default transforms
         are applied.
-    train_target_data :
+    train_target_data : pathlib.Path or str or numpy.ndarray, optional
         Training target data, by default None.
-    val_target_data :
+    val_target_data : pathlib.Path or str or numpy.ndarray, optional
         Validation target data, by default None.
-    read_source_func :
+    read_source_func : Callable, optional
         Function to read the source data, used if `data_type` is `custom`, by
         default None.
     extension_filter : str, optional
@@ -537,19 +566,24 @@ class TrainingDataWrapper(CAREamicsTrainData):
         Use in memory dataset if possible, by default True.
     use_n2v2 : bool, optional
         Use N2V2 transformation during training, by default False.
-    struct_n2v_axis :
+    struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
         Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
         default "none".
     struct_n2v_span : int, optional
         Span for the structN2V mask, by default 5.
 
+    Returns
+    -------
+    TrainDataModule
+        CAREamics training Lightning data module.
+
     Examples
     --------
-    Create a
+    Create a TrainingDataModule with default transforms with a numpy array:
     >>> import numpy as np
-    >>> from careamics import
+    >>> from careamics.lightning import create_train_datamodule
     >>> my_array = np.arange(256).reshape(16, 16)
-    >>> data_module =
+    >>> data_module = create_train_datamodule(
     ...     train_data=my_array,
     ...     data_type="array",
     ...     patch_size=(8, 8),
@@ -560,12 +594,12 @@ class TrainingDataWrapper(CAREamicsTrainData):
     For custom data types (those not supported by CAREamics), then one can pass a read
     function and a filter for the files extension:
     >>> import numpy as np
-    >>> from careamics import
+    >>> from careamics.lightning import create_train_datamodule
     >>>
     >>> def read_npy(path):
     ...     return np.load(path)
     >>>
-    >>> data_module =
+    >>> data_module = create_train_datamodule(
     ...     train_data="path/to/data",
     ...     data_type="custom",
     ...     patch_size=(8, 8),
@@ -578,7 +612,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
     If you want to use a different set of transformations, you can pass a list of
     transforms:
     >>> import numpy as np
-    >>> from careamics import
+    >>> from careamics.lightning import create_train_datamodule
     >>> from careamics.config.support import SupportedTransform
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
@@ -586,7 +620,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
     ...         "name": SupportedTransform.XY_FLIP.value,
     ...     }
     ... ]
-    >>> data_module =
+    >>> data_module = create_train_datamodule(
     ...     train_data=my_array,
     ...     data_type="array",
     ...     patch_size=(8, 8),
@@ -595,166 +629,52 @@ class TrainingDataWrapper(CAREamicsTrainData):
     ...     transforms=my_transforms,
     ... )
     """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    data_dict: dict[str, Any] = {
+        "mode": "train",
+        "data_type": data_type,
+        "patch_size": patch_size,
+        "axes": axes,
+        "batch_size": batch_size,
+        "dataloader_params": dataloader_params,
+    }
+
+    # if transforms are passed (otherwise it will use the default ones)
+    if transforms is not None:
+        data_dict["transforms"] = transforms
+
+    # validate configuration
+    data_config = DataConfig(**data_dict)
+
+    # N2V specific checks, N2V, structN2V, and transforms
+    if data_config.has_n2v_manipulate():
+        # there is not target, n2v2 and structN2V can be changed
+        if train_target_data is None:
+            data_config.set_N2V2(use_n2v2)
+            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+        else:
+            raise ValueError(
+                "Cannot have both supervised training (target data) and "
+                "N2V manipulation in the transforms. Pass a list of transforms "
+                "that is compatible with your supervised training."
+            )
 
-
-
-
-
-
-
-
-
-
-        train_target_data
-        val_target_data
-        read_source_func
-        extension_filter
-        val_percentage
-        val_minimum_patches
-
-
-        use_n2v2: bool = False,
-        struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
-        struct_n2v_span: int = 5,
-    ) -> None:
-        """
-        LightningDataModule wrapper for training and validation datasets.
-
-        Since the lightning datamodule has no access to the model, make sure that the
-        parameters passed to the datamodule are consistent with the model's requirements
-        and are coherent.
-
-        The data module can be used with Path, str or numpy arrays. In the case of
-        numpy arrays, it loads and computes all the patches in memory. For Path and str
-        inputs, it calculates the total file size and estimate whether it can fit in
-        memory. If it does not, it iterates through the files. This behaviour can be
-        deactivated by setting `use_in_memory` to False, in which case it will
-        always use the iterating dataset to train on a Path or str.
-
-        To use array data, set `data_type` to `array` and pass a numpy array to
-        `train_data`.
-
-        In particular, N2V requires a specific transformation (N2V manipulates), which
-        is not compatible with supervised training. The default transformations applied
-        to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms. See examples for more
-        details.
-
-        By default, CAREamics only supports types defined in
-        `careamics.config.support.SupportedData`. To read custom data types, you can set
-        `data_type` to `custom` and provide a function that returns a numpy array from a
-        path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
-        (e.g. "*.jpeg") to filter the files extension using `extension_filter`.
-
-        In the absence of validation data, the validation data is extracted from the
-        training data. The percentage of the training data to use for validation, as
-        well as the minimum number of patches to split from the training data for
-        validation can be set using `val_percentage` and `val_minimum_patches`,
-        respectively.
-
-        In `dataloader_params`, you can pass any parameter accepted by PyTorch
-        dataloaders, except for `batch_size`, which is set by the `batch_size`
-        parameter.
-
-        Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2`
-        to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
-        define the axis and span of the structN2V mask. These parameters are without
-        effect if a `train_target_data` or if `transforms` are provided.
-
-        Parameters
-        ----------
-        train_data : Union[str, Path, np.ndarray]
-            Training data.
-        data_type : Union[str, SupportedData]
-            Data type, see `SupportedData` for available options.
-        patch_size : List[int]
-            Patch size, 2D or 3D patch size.
-        axes : str
-            Axes of the data, choosen amongst SCZYX.
-        batch_size : int
-            Batch size.
-        val_data : Optional[Union[str, Path]], optional
-            Validation data, by default None.
-        transforms : Optional[List[TRANSFORMS_UNION]], optional
-            List of transforms to apply to training patches. If None, default transforms
-            are applied.
-        train_target_data : Optional[Union[str, Path]], optional
-            Training target data, by default None.
-        val_target_data : Optional[Union[str, Path]], optional
-            Validation target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, used if `data_type` is `custom`, by
-            default None.
-        extension_filter : str, optional
-            Filter for file extensions, used if `data_type` is `custom`, by default "".
-        val_percentage : float, optional
-            Percentage of the training data to use for validation if no validation data
-            is given, by default 0.1.
-        val_minimum_patches : int, optional
-            Minimum number of patches to split from the training data for validation if
-            no validation data is given, by default 5.
-        dataloader_params : dict, optional
-            Pytorch dataloader parameters, by default {}.
-        use_in_memory : bool, optional
-            Use in memory dataset if possible, by default True.
-        use_n2v2 : bool, optional
-            Use N2V2 transformation during training, by default False.
-        struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
-            Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
-            default "none".
-        struct_n2v_span : int, optional
-            Span for the structN2V mask, by default 5.
-
-        Raises
-        ------
-        ValueError
-            If a target is set and N2V manipulation is present in the transforms.
-        """
-        if dataloader_params is None:
-            dataloader_params = {}
-        data_dict: Dict[str, Any] = {
-            "mode": "train",
-            "data_type": data_type,
-            "patch_size": patch_size,
-            "axes": axes,
-            "batch_size": batch_size,
-            "dataloader_params": dataloader_params,
-        }
-
-        # if transforms are passed (otherwise it will use the default ones)
-        if transforms is not None:
-            data_dict["transforms"] = transforms
-
-        # validate configuration
-        self.data_config = DataConfig(**data_dict)
-
-        # N2V specific checks, N2V, structN2V, and transforms
-        if self.data_config.has_n2v_manipulate():
-            # there is not target, n2v2 and structN2V can be changed
-            if train_target_data is None:
-                self.data_config.set_N2V2(use_n2v2)
-                self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-            else:
-                raise ValueError(
-                    "Cannot have both supervised training (target data) and "
-                    "N2V manipulation in the transforms. Pass a list of transforms "
-                    "that is compatible with your supervised training."
-                )
-
-        # sanity check on the dataloader parameters
-        if "batch_size" in dataloader_params:
-            # remove it
-            del dataloader_params["batch_size"]
-
-        super().__init__(
-            data_config=self.data_config,
-            train_data=train_data,
-            val_data=val_data,
-            train_data_target=train_target_data,
-            val_data_target=val_target_data,
-            read_source_func=read_source_func,
-            extension_filter=extension_filter,
-            val_percentage=val_percentage,
-            val_minimum_split=val_minimum_patches,
-            use_in_memory=use_in_memory,
-        )
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return TrainDataModule(
+        data_config=data_config,
+        train_data=train_data,
+        val_data=val_data,
+        train_data_target=train_target_data,
+        val_data_target=val_target_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_patches,
+        use_in_memory=use_in_memory,
+    )
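Net effect of this final hunk: the `TrainingDataWrapper` subclass and its duplicated docstring are deleted, and the `create_train_datamodule` factory takes over, validating a `DataConfig` and returning a plain `TrainDataModule`. A migration sketch built from the signatures above (`axes="YX"` is a placeholder value; the wrapper's old import path is not shown in this diff):

```python
import numpy as np

my_array = np.arange(256).reshape(16, 16)

# before (0.1.0rc7): wrapper class
# data_module = TrainingDataWrapper(
#     train_data=my_array, data_type="array", patch_size=(8, 8), axes="YX", batch_size=2
# )

# after (0.1.0rc8): factory function
from careamics.lightning import create_train_datamodule

data_module = create_train_datamodule(
    train_data=my_array,
    data_type="array",
    patch_size=(8, 8),
    axes="YX",
    batch_size=2,
)
```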
careamics/model_io/bmz_io.py
CHANGED

@@ -12,7 +12,7 @@ from torch import __version__, load, save
 
 from careamics.config import Configuration, load_configuration, save_configuration
 from careamics.config.support import SupportedArchitecture
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning.lightning_module import CAREamicsModule
 
 from .bioimage import (
     create_env_text,
careamics/model_io/model_io_utils.py
CHANGED

@@ -6,7 +6,7 @@ from typing import Tuple, Union
 import torch
 
 from careamics.config import Configuration
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning.lightning_module import CAREamicsModule
 from careamics.model_io.bmz_io import load_from_bmz
 from careamics.utils import check_path_exists
 
careamics/prediction_utils/__init__.py
CHANGED

@@ -1,12 +1,10 @@
 """Package to house various prediction utilies."""
 
 __all__ = [
-    "create_pred_datamodule",
     "stitch_prediction",
     "stitch_prediction_single",
     "convert_outputs",
 ]
 
-from .create_pred_datamodule import create_pred_datamodule
 from .prediction_outputs import convert_outputs
 from .stitch_prediction import stitch_prediction, stitch_prediction_single