careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +8 -6
- careamics/careamist.py +18 -18
- careamics/config/__init__.py +12 -8
- careamics/config/algorithm_model.py +5 -5
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +187 -50
- careamics/config/configuration_model.py +8 -7
- careamics/config/data_model.py +3 -3
- careamics/config/inference_model.py +1 -1
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +1 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/nd_flip_model.py +1 -1
- careamics/config/transformations/normalize_model.py +1 -1
- careamics/config/transformations/xy_random_rotate90_model.py +1 -1
- careamics/dataset/in_memory_dataset.py +3 -3
- careamics/dataset/iterable_dataset.py +3 -3
- careamics/lightning_datamodule.py +103 -25
- careamics/lightning_module.py +6 -6
- careamics/lightning_prediction_datamodule.py +44 -38
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +6 -6
- careamics/model_io/model_io_utils.py +4 -4
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/RECORD +27 -26
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
"""Training and validation Lightning data modules."""
|
|
1
2
|
from pathlib import Path
|
|
2
3
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
3
4
|
|
|
@@ -6,7 +7,7 @@ import pytorch_lightning as L
|
|
|
6
7
|
from albumentations import Compose
|
|
7
8
|
from torch.utils.data import DataLoader
|
|
8
9
|
|
|
9
|
-
from careamics.config import
|
|
10
|
+
from careamics.config import DataConfig
|
|
10
11
|
from careamics.config.data_model import TRANSFORMS_UNION
|
|
11
12
|
from careamics.config.support import SupportedData
|
|
12
13
|
from careamics.dataset.dataset_utils import (
|
|
@@ -28,9 +29,9 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
|
|
|
28
29
|
logger = get_logger(__name__)
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
class
|
|
32
|
+
class CAREamicsTrainData(L.LightningDataModule):
|
|
32
33
|
"""
|
|
33
|
-
|
|
34
|
+
CAREamics Ligthning training and validation data module.
|
|
34
35
|
|
|
35
36
|
The data module can be used with Path, str or numpy arrays. In the case of
|
|
36
37
|
numpy arrays, it loads and computes all the patches in memory. For Path and str
|
|
@@ -53,11 +54,70 @@ class CAREamicsWood(L.LightningDataModule):
|
|
|
53
54
|
|
|
54
55
|
You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
|
|
55
56
|
"*.czi") to filter the files extension using `extension_filter`.
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
data_config : DataModel
|
|
61
|
+
Pydantic model for CAREamics data configuration.
|
|
62
|
+
train_data : Union[Path, str, np.ndarray]
|
|
63
|
+
Training data, can be a path to a folder, a file or a numpy array.
|
|
64
|
+
val_data : Optional[Union[Path, str, np.ndarray]], optional
|
|
65
|
+
Validation data, can be a path to a folder, a file or a numpy array, by
|
|
66
|
+
default None.
|
|
67
|
+
train_data_target : Optional[Union[Path, str, np.ndarray]], optional
|
|
68
|
+
Training target data, can be a path to a folder, a file or a numpy array, by
|
|
69
|
+
default None.
|
|
70
|
+
val_data_target : Optional[Union[Path, str, np.ndarray]], optional
|
|
71
|
+
Validation target data, can be a path to a folder, a file or a numpy array,
|
|
72
|
+
by default None.
|
|
73
|
+
read_source_func : Optional[Callable], optional
|
|
74
|
+
Function to read the source data, by default None. Only used for `custom`
|
|
75
|
+
data type (see DataModel).
|
|
76
|
+
extension_filter : str, optional
|
|
77
|
+
Filter for file extensions, by default "". Only used for `custom` data types
|
|
78
|
+
(see DataModel).
|
|
79
|
+
val_percentage : float, optional
|
|
80
|
+
Percentage of the training data to use for validation, by default 0.1. Only
|
|
81
|
+
used if `val_data` is None.
|
|
82
|
+
val_minimum_split : int, optional
|
|
83
|
+
Minimum number of patches or files to split from the training data for
|
|
84
|
+
validation, by default 5. Only used if `val_data` is None.
|
|
85
|
+
use_in_memory : bool, optional
|
|
86
|
+
Use in memory dataset if possible, by default True.
|
|
87
|
+
|
|
88
|
+
Attributes
|
|
89
|
+
----------
|
|
90
|
+
data_config : DataModel
|
|
91
|
+
CAREamics data configuration.
|
|
92
|
+
data_type : SupportedData
|
|
93
|
+
Expected data type, one of "tiff", "array" or "custom".
|
|
94
|
+
batch_size : int
|
|
95
|
+
Batch size.
|
|
96
|
+
use_in_memory : bool
|
|
97
|
+
Whether to use in memory dataset if possible.
|
|
98
|
+
train_data : Union[Path, str, np.ndarray]
|
|
99
|
+
Training data.
|
|
100
|
+
val_data : Optional[Union[Path, str, np.ndarray]]
|
|
101
|
+
Validation data.
|
|
102
|
+
train_data_target : Optional[Union[Path, str, np.ndarray]]
|
|
103
|
+
Training target data.
|
|
104
|
+
val_data_target : Optional[Union[Path, str, np.ndarray]]
|
|
105
|
+
Validation target data.
|
|
106
|
+
val_percentage : float
|
|
107
|
+
Percentage of the training data to use for validation, if no validation data is
|
|
108
|
+
provided.
|
|
109
|
+
val_minimum_split : int
|
|
110
|
+
Minimum number of patches or files to split from the training data for
|
|
111
|
+
validation, if no validation data is provided.
|
|
112
|
+
read_source_func : Optional[Callable]
|
|
113
|
+
Function to read the source data, used if `data_type` is `custom`.
|
|
114
|
+
extension_filter : str
|
|
115
|
+
Filter for file extensions, used if `data_type` is `custom`.
|
|
56
116
|
"""
|
|
57
117
|
|
|
58
118
|
def __init__(
|
|
59
119
|
self,
|
|
60
|
-
data_config:
|
|
120
|
+
data_config: DataConfig,
|
|
61
121
|
train_data: Union[Path, str, np.ndarray],
|
|
62
122
|
val_data: Optional[Union[Path, str, np.ndarray]] = None,
|
|
63
123
|
train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
|
|
@@ -98,6 +158,8 @@ class CAREamicsWood(L.LightningDataModule):
|
|
|
98
158
|
val_minimum_split : int, optional
|
|
99
159
|
Minimum number of patches or files to split from the training data for
|
|
100
160
|
validation, by default 5. Only used if `val_data` is None.
|
|
161
|
+
use_in_memory : bool, optional
|
|
162
|
+
Use in memory dataset if possible, by default True.
|
|
101
163
|
|
|
102
164
|
Raises
|
|
103
165
|
------
|
|
@@ -128,25 +190,30 @@ class CAREamicsWood(L.LightningDataModule):
|
|
|
128
190
|
if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
|
|
129
191
|
raise ValueError(
|
|
130
192
|
f"Data type {SupportedData.CUSTOM} is not allowed without "
|
|
131
|
-
f"specifying a `read_source_func`."
|
|
193
|
+
f"specifying a `read_source_func` and an `extension_filer`."
|
|
132
194
|
)
|
|
133
195
|
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
train_data, np.ndarray
|
|
196
|
+
# check correct input type
|
|
197
|
+
if (
|
|
198
|
+
isinstance(train_data, np.ndarray)
|
|
199
|
+
and data_config.data_type != SupportedData.ARRAY
|
|
137
200
|
):
|
|
138
201
|
raise ValueError(
|
|
139
|
-
f"
|
|
140
|
-
f"{
|
|
202
|
+
f"Received a numpy array as input, but the data type was set to "
|
|
203
|
+
f"{data_config.data_type}. Set the data type in the configuration "
|
|
204
|
+
f"to {SupportedData.ARRAY} to train on numpy arrays."
|
|
141
205
|
)
|
|
142
206
|
|
|
143
207
|
# and that Path or str are passed, if tiff file type specified
|
|
144
|
-
elif
|
|
145
|
-
|
|
208
|
+
elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
|
|
209
|
+
data_config.data_type != SupportedData.TIFF
|
|
210
|
+
and data_config.data_type != SupportedData.CUSTOM
|
|
146
211
|
):
|
|
147
212
|
raise ValueError(
|
|
148
|
-
f"
|
|
149
|
-
f"
|
|
213
|
+
f"Received a path as input, but the data type was neither set to "
|
|
214
|
+
f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
|
|
215
|
+
f"in the configuration to {SupportedData.TIFF} or "
|
|
216
|
+
f"{SupportedData.CUSTOM} to train on files."
|
|
150
217
|
)
|
|
151
218
|
|
|
152
219
|
# configuration
|
|
@@ -231,7 +298,15 @@ class CAREamicsWood(L.LightningDataModule):
|
|
|
231
298
|
validate_source_target_files(self.val_files, self.val_target_files)
|
|
232
299
|
|
|
233
300
|
def setup(self, *args: Any, **kwargs: Any) -> None:
|
|
234
|
-
"""Hook called at the beginning of fit, validate, or predict.
|
|
301
|
+
"""Hook called at the beginning of fit, validate, or predict.
|
|
302
|
+
|
|
303
|
+
Parameters
|
|
304
|
+
----------
|
|
305
|
+
*args : Any
|
|
306
|
+
Unused.
|
|
307
|
+
**kwargs : Any
|
|
308
|
+
Unused.
|
|
309
|
+
"""
|
|
235
310
|
# if numpy array
|
|
236
311
|
if self.data_type == SupportedData.ARRAY:
|
|
237
312
|
# train dataset
|
|
@@ -353,9 +428,12 @@ class CAREamicsWood(L.LightningDataModule):
|
|
|
353
428
|
)
|
|
354
429
|
|
|
355
430
|
|
|
356
|
-
class
|
|
431
|
+
class TrainingDataWrapper(CAREamicsTrainData):
|
|
357
432
|
"""
|
|
358
|
-
|
|
433
|
+
Wrapper around the CAREamics Lightning training data module.
|
|
434
|
+
|
|
435
|
+
This class is used to explicitely pass the parameters usually contained in a
|
|
436
|
+
`data_model` configuration.
|
|
359
437
|
|
|
360
438
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
361
439
|
parameters passed to the datamodule are consistent with the model's requirements and
|
|
@@ -442,11 +520,11 @@ class CAREamicsTrainDataModule(CAREamicsWood):
|
|
|
442
520
|
|
|
443
521
|
Examples
|
|
444
522
|
--------
|
|
445
|
-
Create a
|
|
523
|
+
Create a TrainingDataWrapper with default transforms with a numpy array:
|
|
446
524
|
>>> import numpy as np
|
|
447
|
-
>>> from careamics import
|
|
525
|
+
>>> from careamics import TrainingDataWrapper
|
|
448
526
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
449
|
-
>>> data_module =
|
|
527
|
+
>>> data_module = TrainingDataWrapper(
|
|
450
528
|
... train_data=my_array,
|
|
451
529
|
... data_type="array",
|
|
452
530
|
... patch_size=(8, 8),
|
|
@@ -457,12 +535,12 @@ class CAREamicsTrainDataModule(CAREamicsWood):
|
|
|
457
535
|
For custom data types (those not supported by CAREamics), then one can pass a read
|
|
458
536
|
function and a filter for the files extension:
|
|
459
537
|
>>> import numpy as np
|
|
460
|
-
>>> from careamics import
|
|
538
|
+
>>> from careamics import TrainingDataWrapper
|
|
461
539
|
>>>
|
|
462
540
|
>>> def read_npy(path):
|
|
463
541
|
... return np.load(path)
|
|
464
542
|
>>>
|
|
465
|
-
>>> data_module =
|
|
543
|
+
>>> data_module = TrainingDataWrapper(
|
|
466
544
|
... train_data="path/to/data",
|
|
467
545
|
... data_type="custom",
|
|
468
546
|
... patch_size=(8, 8),
|
|
@@ -475,7 +553,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
|
|
|
475
553
|
If you want to use a different set of transformations, you can pass a list of
|
|
476
554
|
transforms:
|
|
477
555
|
>>> import numpy as np
|
|
478
|
-
>>> from careamics import
|
|
556
|
+
>>> from careamics import TrainingDataWrapper
|
|
479
557
|
>>> from careamics.config.support import SupportedTransform
|
|
480
558
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
481
559
|
>>> my_transforms = [
|
|
@@ -488,7 +566,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
|
|
|
488
566
|
... "name": SupportedTransform.N2V_MANIPULATE.value,
|
|
489
567
|
... }
|
|
490
568
|
... ]
|
|
491
|
-
>>> data_module =
|
|
569
|
+
>>> data_module = TrainingDataWrapper(
|
|
492
570
|
... train_data=my_array,
|
|
493
571
|
... data_type="array",
|
|
494
572
|
... patch_size=(8, 8),
|
|
@@ -628,7 +706,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
|
|
|
628
706
|
data_dict["transforms"] = transforms
|
|
629
707
|
|
|
630
708
|
# validate configuration
|
|
631
|
-
self.data_config =
|
|
709
|
+
self.data_config = DataConfig(**data_dict)
|
|
632
710
|
|
|
633
711
|
# N2V specific checks, N2V, structN2V, and transforms
|
|
634
712
|
if (
|
careamics/lightning_module.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
|
|
|
3
3
|
import pytorch_lightning as L
|
|
4
4
|
from torch import Tensor, nn
|
|
5
5
|
|
|
6
|
-
from careamics.config import
|
|
6
|
+
from careamics.config import AlgorithmConfig
|
|
7
7
|
from careamics.config.support import (
|
|
8
8
|
SupportedAlgorithm,
|
|
9
9
|
SupportedArchitecture,
|
|
@@ -17,7 +17,7 @@ from careamics.transforms import Denormalize, ImageRestorationTTA
|
|
|
17
17
|
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class
|
|
20
|
+
class CAREamicsModule(L.LightningModule):
|
|
21
21
|
"""
|
|
22
22
|
CAREamics Lightning module.
|
|
23
23
|
|
|
@@ -38,7 +38,7 @@ class CAREamicsKiln(L.LightningModule):
|
|
|
38
38
|
Learning rate scheduler name.
|
|
39
39
|
"""
|
|
40
40
|
|
|
41
|
-
def __init__(self, algorithm_config: Union[
|
|
41
|
+
def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
|
|
42
42
|
"""
|
|
43
43
|
CAREamics Lightning module.
|
|
44
44
|
|
|
@@ -53,7 +53,7 @@ class CAREamicsKiln(L.LightningModule):
|
|
|
53
53
|
super().__init__()
|
|
54
54
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
55
55
|
if isinstance(algorithm_config, dict):
|
|
56
|
-
algorithm_config =
|
|
56
|
+
algorithm_config = AlgorithmConfig(**algorithm_config)
|
|
57
57
|
|
|
58
58
|
# create model and loss function
|
|
59
59
|
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
@@ -192,7 +192,7 @@ class CAREamicsKiln(L.LightningModule):
|
|
|
192
192
|
}
|
|
193
193
|
|
|
194
194
|
|
|
195
|
-
class CAREamicsModule
|
|
195
|
+
class CAREamicsModuleWrapper(CAREamicsModule):
|
|
196
196
|
"""Class defining the API for CAREamics Lightning layer.
|
|
197
197
|
|
|
198
198
|
This class exposes parameters used to create an AlgorithmModel instance, triggering
|
|
@@ -287,6 +287,6 @@ class CAREamicsModule(CAREamicsKiln):
|
|
|
287
287
|
algorithm_configuration["model"] = model_configuration
|
|
288
288
|
|
|
289
289
|
# call the parent init using an AlgorithmModel instance
|
|
290
|
-
super().__init__(
|
|
290
|
+
super().__init__(AlgorithmConfig(**algorithm_configuration))
|
|
291
291
|
|
|
292
292
|
# TODO add load_from_checkpoint wrapper
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
"""Prediction Lightning data modules."""
|
|
1
2
|
from pathlib import Path
|
|
2
3
|
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|
3
4
|
|
|
@@ -7,7 +8,7 @@ from albumentations import Compose
|
|
|
7
8
|
from torch.utils.data import DataLoader
|
|
8
9
|
from torch.utils.data.dataloader import default_collate
|
|
9
10
|
|
|
10
|
-
from careamics.config import
|
|
11
|
+
from careamics.config import InferenceConfig
|
|
11
12
|
from careamics.config.support import SupportedData
|
|
12
13
|
from careamics.config.tile_information import TileInformation
|
|
13
14
|
from careamics.dataset.dataset_utils import (
|
|
@@ -62,9 +63,9 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
|
62
63
|
return default_collate(new_batch)
|
|
63
64
|
|
|
64
65
|
|
|
65
|
-
class
|
|
66
|
+
class CAREamicsPredictData(L.LightningDataModule):
|
|
66
67
|
"""
|
|
67
|
-
|
|
68
|
+
CAREamics Lightning prediction data module.
|
|
68
69
|
|
|
69
70
|
The data module can be used with Path, str or numpy arrays. The data can be either
|
|
70
71
|
a folder containing images or a single file.
|
|
@@ -79,7 +80,7 @@ class CAREamicsClay(L.LightningDataModule):
|
|
|
79
80
|
|
|
80
81
|
Parameters
|
|
81
82
|
----------
|
|
82
|
-
|
|
83
|
+
pred_config : InferenceModel
|
|
83
84
|
Pydantic model for CAREamics prediction configuration.
|
|
84
85
|
pred_data : Union[Path, str, np.ndarray]
|
|
85
86
|
Prediction data, can be a path to a folder, a file or a numpy array.
|
|
@@ -93,7 +94,7 @@ class CAREamicsClay(L.LightningDataModule):
|
|
|
93
94
|
|
|
94
95
|
def __init__(
|
|
95
96
|
self,
|
|
96
|
-
|
|
97
|
+
pred_config: InferenceConfig,
|
|
97
98
|
pred_data: Union[Path, str, np.ndarray],
|
|
98
99
|
read_source_func: Optional[Callable] = None,
|
|
99
100
|
extension_filter: str = "",
|
|
@@ -115,7 +116,7 @@ class CAREamicsClay(L.LightningDataModule):
|
|
|
115
116
|
|
|
116
117
|
Parameters
|
|
117
118
|
----------
|
|
118
|
-
|
|
119
|
+
pred_config : InferenceModel
|
|
119
120
|
Pydantic model for CAREamics prediction configuration.
|
|
120
121
|
pred_data : Union[Path, str, np.ndarray]
|
|
121
122
|
Prediction data, can be a path to a folder, a file or a numpy array.
|
|
@@ -142,51 +143,53 @@ class CAREamicsClay(L.LightningDataModule):
|
|
|
142
143
|
super().__init__()
|
|
143
144
|
|
|
144
145
|
# check that a read source function is provided for custom types
|
|
145
|
-
if
|
|
146
|
-
prediction_config.data_type == SupportedData.CUSTOM
|
|
147
|
-
and read_source_func is None
|
|
148
|
-
):
|
|
146
|
+
if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
|
|
149
147
|
raise ValueError(
|
|
150
148
|
f"Data type {SupportedData.CUSTOM} is not allowed without "
|
|
151
|
-
f"specifying a `read_source_func`."
|
|
149
|
+
f"specifying a `read_source_func` and an `extension_filer`."
|
|
152
150
|
)
|
|
153
151
|
|
|
154
|
-
#
|
|
155
|
-
|
|
156
|
-
pred_data, np.ndarray
|
|
152
|
+
# check correct input type
|
|
153
|
+
if (
|
|
154
|
+
isinstance(pred_data, np.ndarray)
|
|
155
|
+
and pred_config.data_type != SupportedData.ARRAY
|
|
157
156
|
):
|
|
158
157
|
raise ValueError(
|
|
159
|
-
f"
|
|
160
|
-
f"{
|
|
158
|
+
f"Received a numpy array as input, but the data type was set to "
|
|
159
|
+
f"{pred_config.data_type}. Set the data type "
|
|
160
|
+
f"to {SupportedData.ARRAY} to predict on numpy arrays."
|
|
161
161
|
)
|
|
162
162
|
|
|
163
163
|
# and that Path or str are passed, if tiff file type specified
|
|
164
|
-
elif
|
|
165
|
-
|
|
164
|
+
elif (isinstance(pred_data, Path) or isinstance(pred_config, str)) and (
|
|
165
|
+
pred_config.data_type != SupportedData.TIFF
|
|
166
|
+
and pred_config.data_type != SupportedData.CUSTOM
|
|
166
167
|
):
|
|
167
168
|
raise ValueError(
|
|
168
|
-
f"
|
|
169
|
-
f"
|
|
169
|
+
f"Received a path as input, but the data type was neither set to "
|
|
170
|
+
f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
|
|
171
|
+
f" to {SupportedData.TIFF} or "
|
|
172
|
+
f"{SupportedData.CUSTOM} to predict on files."
|
|
170
173
|
)
|
|
171
174
|
|
|
172
175
|
# configuration data
|
|
173
|
-
self.prediction_config =
|
|
174
|
-
self.data_type =
|
|
175
|
-
self.batch_size =
|
|
176
|
+
self.prediction_config = pred_config
|
|
177
|
+
self.data_type = pred_config.data_type
|
|
178
|
+
self.batch_size = pred_config.batch_size
|
|
176
179
|
self.dataloader_params = dataloader_params
|
|
177
180
|
|
|
178
181
|
self.pred_data = pred_data
|
|
179
|
-
self.tile_size =
|
|
180
|
-
self.tile_overlap =
|
|
182
|
+
self.tile_size = pred_config.tile_size
|
|
183
|
+
self.tile_overlap = pred_config.tile_overlap
|
|
181
184
|
|
|
182
185
|
# read source function
|
|
183
|
-
if
|
|
186
|
+
if pred_config.data_type == SupportedData.CUSTOM:
|
|
184
187
|
# mypy check
|
|
185
188
|
assert read_source_func is not None
|
|
186
189
|
|
|
187
190
|
self.read_source_func: Callable = read_source_func
|
|
188
|
-
elif
|
|
189
|
-
self.read_source_func = get_read_func(
|
|
191
|
+
elif pred_config.data_type != SupportedData.ARRAY:
|
|
192
|
+
self.read_source_func = get_read_func(pred_config.data_type)
|
|
190
193
|
|
|
191
194
|
self.extension_filter = extension_filter
|
|
192
195
|
|
|
@@ -238,9 +241,12 @@ class CAREamicsClay(L.LightningDataModule):
|
|
|
238
241
|
) # TODO check workers are used
|
|
239
242
|
|
|
240
243
|
|
|
241
|
-
class
|
|
244
|
+
class PredictDataWrapper(CAREamicsPredictData):
|
|
242
245
|
"""
|
|
243
|
-
|
|
246
|
+
Wrapper around the CAREamics inference Lightning data module.
|
|
247
|
+
|
|
248
|
+
This class is used to explicitely pass the parameters usually contained in a
|
|
249
|
+
`inference_model` configuration.
|
|
244
250
|
|
|
245
251
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
246
252
|
parameters passed to the datamodule are consistent with the model's requirements
|
|
@@ -329,6 +335,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
|
|
|
329
335
|
Prediction data.
|
|
330
336
|
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
|
|
331
337
|
Data type, see `SupportedData` for available options.
|
|
338
|
+
mean : float
|
|
339
|
+
Mean value for normalization, only used if Normalization is defined in the
|
|
340
|
+
transforms.
|
|
341
|
+
std : float
|
|
342
|
+
Standard deviation value for normalization, only used if Normalization is
|
|
343
|
+
defined in the transform.
|
|
332
344
|
tile_size : List[int]
|
|
333
345
|
Tile size, 2D or 3D tile size.
|
|
334
346
|
tile_overlap : List[int]
|
|
@@ -339,12 +351,6 @@ class CAREamicsPredictDataModule(CAREamicsClay):
|
|
|
339
351
|
Batch size.
|
|
340
352
|
tta_transforms : bool, optional
|
|
341
353
|
Use test time augmentation, by default True.
|
|
342
|
-
mean : Optional[float], optional
|
|
343
|
-
Mean value for normalization, only used if Normalization is defined, by
|
|
344
|
-
default None.
|
|
345
|
-
std : Optional[float], optional
|
|
346
|
-
Standard deviation value for normalization, only used if Normalization is
|
|
347
|
-
defined, by default None.
|
|
348
354
|
transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
|
|
349
355
|
List of transforms to apply to prediction patches. If None, default
|
|
350
356
|
transforms are applied.
|
|
@@ -374,7 +380,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
|
|
|
374
380
|
prediction_dict["transforms"] = transforms
|
|
375
381
|
|
|
376
382
|
# validate configuration
|
|
377
|
-
self.prediction_config =
|
|
383
|
+
self.prediction_config = InferenceConfig(**prediction_dict)
|
|
378
384
|
|
|
379
385
|
# sanity check on the dataloader parameters
|
|
380
386
|
if "batch_size" in dataloader_params:
|
|
@@ -382,7 +388,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
|
|
|
382
388
|
del dataloader_params["batch_size"]
|
|
383
389
|
|
|
384
390
|
super().__init__(
|
|
385
|
-
|
|
391
|
+
pred_config=self.prediction_config,
|
|
386
392
|
pred_data=pred_data,
|
|
387
393
|
read_source_func=read_source_func,
|
|
388
394
|
extension_filter=extension_filter,
|
|
@@ -26,14 +26,14 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
26
26
|
WeightsDescr,
|
|
27
27
|
)
|
|
28
28
|
|
|
29
|
-
from careamics.config import Configuration,
|
|
29
|
+
from careamics.config import Configuration, DataConfig
|
|
30
30
|
|
|
31
31
|
from ._readme_factory import readme_factory
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def _create_axes(
|
|
35
35
|
array: np.ndarray,
|
|
36
|
-
data_config:
|
|
36
|
+
data_config: DataConfig,
|
|
37
37
|
channel_names: Optional[List[str]] = None,
|
|
38
38
|
is_input: bool = True,
|
|
39
39
|
) -> List[AxisBase]:
|
|
@@ -100,7 +100,7 @@ def _create_axes(
|
|
|
100
100
|
def _create_inputs_ouputs(
|
|
101
101
|
input_array: np.ndarray,
|
|
102
102
|
output_array: np.ndarray,
|
|
103
|
-
data_config:
|
|
103
|
+
data_config: DataConfig,
|
|
104
104
|
input_path: Union[Path, str],
|
|
105
105
|
output_path: Union[Path, str],
|
|
106
106
|
channel_names: Optional[List[str]] = None,
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -11,7 +11,7 @@ from torch import __version__, load, save
|
|
|
11
11
|
|
|
12
12
|
from careamics.config import Configuration, load_configuration, save_configuration
|
|
13
13
|
from careamics.config.support import SupportedArchitecture
|
|
14
|
-
from careamics.lightning_module import
|
|
14
|
+
from careamics.lightning_module import CAREamicsModule
|
|
15
15
|
|
|
16
16
|
from .bioimage import (
|
|
17
17
|
create_env_text,
|
|
@@ -21,7 +21,7 @@ from .bioimage import (
|
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def _export_state_dict(model:
|
|
24
|
+
def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
|
|
25
25
|
"""
|
|
26
26
|
Export the model state dictionary to a file.
|
|
27
27
|
|
|
@@ -51,7 +51,7 @@ def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
|
|
|
51
51
|
return path
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def _load_state_dict(model:
|
|
54
|
+
def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
|
|
55
55
|
"""
|
|
56
56
|
Load a model from a state dictionary.
|
|
57
57
|
|
|
@@ -73,7 +73,7 @@ def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
|
|
|
73
73
|
|
|
74
74
|
# TODO break down in subfunctions
|
|
75
75
|
def export_to_bmz(
|
|
76
|
-
model:
|
|
76
|
+
model: CAREamicsModule,
|
|
77
77
|
config: Configuration,
|
|
78
78
|
path: Union[Path, str],
|
|
79
79
|
name: str,
|
|
@@ -185,7 +185,7 @@ def export_to_bmz(
|
|
|
185
185
|
save_bioimageio_package(model_description, output_path=path)
|
|
186
186
|
|
|
187
187
|
|
|
188
|
-
def load_from_bmz(path: Union[Path, str]) -> Tuple[
|
|
188
|
+
def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
|
|
189
189
|
"""Load a model from a BioImage Model Zoo archive.
|
|
190
190
|
|
|
191
191
|
Parameters
|
|
@@ -223,7 +223,7 @@ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]
|
|
|
223
223
|
config = load_configuration(config_path)
|
|
224
224
|
|
|
225
225
|
# create careamics lightning module
|
|
226
|
-
model =
|
|
226
|
+
model = CAREamicsModule(algorithm_config=config.algorithm_config)
|
|
227
227
|
|
|
228
228
|
# load model state dictionary
|
|
229
229
|
_load_state_dict(model, weights_path)
|
|
@@ -6,12 +6,12 @@ from typing import Tuple, Union
|
|
|
6
6
|
from torch import load
|
|
7
7
|
|
|
8
8
|
from careamics.config import Configuration
|
|
9
|
-
from careamics.lightning_module import
|
|
9
|
+
from careamics.lightning_module import CAREamicsModule
|
|
10
10
|
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
11
|
from careamics.utils import check_path_exists
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def load_pretrained(path: Union[Path, str]) -> Tuple[
|
|
14
|
+
def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
|
|
15
15
|
"""
|
|
16
16
|
Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
|
|
17
17
|
|
|
@@ -44,7 +44,7 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio
|
|
|
44
44
|
)
|
|
45
45
|
|
|
46
46
|
|
|
47
|
-
def _load_checkpoint(path: Union[Path, str]) -> Tuple[
|
|
47
|
+
def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
|
|
48
48
|
"""
|
|
49
49
|
Load a model from a checkpoint and return both model and configuration.
|
|
50
50
|
|
|
@@ -75,6 +75,6 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configurati
|
|
|
75
75
|
f"checkpoint: {checkpoint.keys()}"
|
|
76
76
|
) from e
|
|
77
77
|
|
|
78
|
-
model =
|
|
78
|
+
model = CAREamicsModule.load_from_checkpoint(path)
|
|
79
79
|
|
|
80
80
|
return model, Configuration(**cfg_dict)
|