careamics-0.1.0rc7-py3-none-any.whl → careamics-0.1.0rc8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +1 -14
- careamics/careamist.py +83 -62
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -0
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +2 -0
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +1 -79
- careamics/config/configuration_model.py +12 -7
- careamics/config/data_model.py +29 -10
- careamics/config/inference_model.py +12 -2
- careamics/config/optimizer_models.py +6 -0
- careamics/config/support/supported_data.py +29 -4
- careamics/config/tile_information.py +10 -0
- careamics/config/training_model.py +5 -1
- careamics/dataset/dataset_utils/__init__.py +0 -6
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
- careamics/dataset/in_memory_dataset.py +37 -21
- careamics/dataset/iterable_dataset.py +38 -34
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/patching.py +53 -37
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +1 -1
- careamics/prediction_utils/__init__.py +0 -2
- careamics/prediction_utils/prediction_outputs.py +18 -46
- careamics/prediction_utils/stitch_prediction.py +17 -14
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
- careamics/config/configuration_example.py +0 -86
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/prediction_utils/create_pred_datamodule.py +0 -185
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/patching/patching.py
@@ -2,9 +2,10 @@
 
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable,
+from typing import Callable, Union
 
 import numpy as np
+from numpy.typing import NDArray
 
 from ...utils.logging import get_logger
 from ..dataset_utils import reshape_array
@@ -18,34 +19,49 @@ logger = get_logger(__name__)
 class Stats:
     """Dataclass to store statistics."""
 
-    means: Union[
-
+    means: Union[NDArray, tuple, list, None]
+    """Mean of the data across channels."""
 
+    stds: Union[NDArray, tuple, list, None]
+    """Standard deviation of the data across channels."""
 
-
-
-    """Dataclass to store patches and statistics."""
+    def get_statistics(self) -> tuple[list[float], list[float]]:
+        """Return the means and standard deviations.
 
-
-
-
-
+        Returns
+        -------
+        tuple of two lists of floats
+            Means and standard deviations.
+        """
+        if self.means is None or self.stds is None:
+            return [], []
+
+        return list(self.means), list(self.stds)
 
 
 @dataclass
-class
+class PatchedOutput:
     """Dataclass to store patches and statistics."""
 
+    patches: Union[NDArray]
+    """Image patches."""
+
+    targets: Union[NDArray, None]
+    """Target patches."""
+
     image_stats: Stats
+    """Statistics of the image patches."""
+
     target_stats: Stats
+    """Statistics of the target patches."""
 
 
 # called by in memory dataset
 def prepare_patches_supervised(
-    train_files:
-    target_files:
+    train_files: list[Path],
+    target_files: list[Path],
     axes: str,
-    patch_size: Union[
+    patch_size: Union[list[int], tuple[int, ...]],
     read_source_func: Callable,
 ) -> PatchedOutput:
     """
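For orientation, a minimal sketch of how the reworked dataclasses compose; the module path is taken from the file list above, and the snippet itself is not part of the diff:

import numpy as np

from careamics.dataset.patching.patching import PatchedOutput, Stats

# Stats now stores per-channel means/stds and exposes a helper to unpack them.
stats = Stats(means=np.array([0.5]), stds=np.array([0.1]))
assert stats.get_statistics() == ([0.5], [0.1])
assert Stats(means=None, stds=None).get_statistics() == ([], [])

# PatchedOutput bundles patches, optional targets, and the two Stats objects.
output = PatchedOutput(
    patches=np.zeros((10, 1, 64, 64), dtype=np.float32),
    targets=None,  # unsupervised case
    image_stats=stats,
    target_stats=Stats(means=None, stds=None),
)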
@@ -55,13 +71,13 @@ def prepare_patches_supervised(
 
     Parameters
     ----------
-    train_files :
+    train_files : list of pathlib.Path
         List of paths to training data.
-    target_files :
+    target_files : list of pathlib.Path
         List of paths to target data.
     axes : str
         Axes of the data.
-    patch_size :
+    patch_size : list or tuple of int
         Size of the patches.
     read_source_func : Callable
         Function to read the data.
@@ -127,9 +143,9 @@ def prepare_patches_supervised(
 
 # called by in_memory_dataset
 def prepare_patches_unsupervised(
-    train_files:
+    train_files: list[Path],
     axes: str,
-    patch_size: Union[
+    patch_size: Union[list[int], tuple[int]],
     read_source_func: Callable,
 ) -> PatchedOutput:
     """Iterate over data source and create an array of patches.
@@ -138,19 +154,19 @@ def prepare_patches_unsupervised(
 
     Parameters
     ----------
-    train_files :
+    train_files : list of pathlib.Path
         List of paths to training data.
     axes : str
         Axes of the data.
-    patch_size :
+    patch_size : list or tuple of int
         Size of the patches.
     read_source_func : Callable
         Function to read the data.
 
     Returns
     -------
-
-
+    PatchedOutput
+        Dataclass holding patches and their statistics.
     """
     means, stds, num_samples = 0, 0, 0
     all_patches = []
@@ -189,10 +205,10 @@ def prepare_patches_unsupervised(
 
 # called on arrays by in memory dataset
 def prepare_patches_supervised_array(
-    data:
+    data: NDArray,
     axes: str,
-    data_target:
-    patch_size: Union[
+    data_target: NDArray,
+    patch_size: Union[list[int], tuple[int]],
 ) -> PatchedOutput:
     """Iterate over data source and create an array of patches.
 
@@ -203,19 +219,19 @@ def prepare_patches_supervised_array(
 
     Parameters
     ----------
-    data :
+    data : numpy.ndarray
         Input data array.
     axes : str
         Axes of the data.
-    data_target :
+    data_target : numpy.ndarray
         Target data array.
-    patch_size :
+    patch_size : list or tuple of int
         Size of the patches.
 
     Returns
     -------
-
-
+    PatchedOutput
+        Dataclass holding the source and target patches, with their statistics.
     """
     # reshape array
     reshaped_sample = reshape_array(data, axes)
@@ -245,9 +261,9 @@ def prepare_patches_supervised_array(
 
 # called by in memory dataset
 def prepare_patches_unsupervised_array(
-    data:
+    data: NDArray,
    axes: str,
-    patch_size: Union[
+    patch_size: Union[list[int], tuple[int]],
 ) -> PatchedOutput:
     """
     Iterate over data source and create an array of patches.
@@ -259,17 +275,17 @@ def prepare_patches_unsupervised_array(
 
     Parameters
     ----------
-    data :
+    data : numpy.ndarray
         Input data array.
     axes : str
         Axes of the data.
-    patch_size :
+    patch_size : list or tuple of int
         Size of the patches.
 
     Returns
     -------
-
-
+    PatchedOutput
+        Dataclass holding the patches and their statistics.
     """
     # reshape array
     reshaped_sample = reshape_array(data, axes)
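A sketch of calling one of the retyped entry points under the new signatures; the argument values are illustrative and the patch extraction details live elsewhere in the module:

import numpy as np

from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

# A single 2D image; axes follows the CAREamics convention and patch_size
# must match the number of spatial dimensions.
data = np.random.rand(256, 256).astype(np.float32)
patched = prepare_patches_unsupervised_array(
    data=data,
    axes="YX",
    patch_size=[64, 64],
)
print(patched.patches.shape)
print(patched.image_stats.get_statistics())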
careamics/file_io/read/get_func.py (new file)
@@ -0,0 +1,56 @@
+"""Module to get read functions."""
+
+from pathlib import Path
+from typing import Callable, Dict, Protocol, Union
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import read_tiff
+
+
+# This is very strict, function signature has to match including arg names
+# See WriteFunc notes
+class ReadFunc(Protocol):
+    """Protocol for type hinting read functions."""
+
+    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+READ_FUNCS: Dict[SupportedData, ReadFunc] = {
+    SupportedData.TIFF: read_tiff,
+}
+
+
+def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
+    """
+    Get the read function for the data type.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        Data type.
+
+    Returns
+    -------
+    callable
+        Read function.
+    """
+    if data_type in READ_FUNCS:
+        data_type = SupportedData(data_type)  # mypy complaining about dict key type
+        return READ_FUNCS[data_type]
+    else:
+        raise NotImplementedError(f"Data type '{data_type}' is not supported.")
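Presumably this is used as a simple registry lookup; a short sketch, assuming `get_read_func` is re-exported from `careamics.file_io.read` via the new `__init__.py` listed above:

from pathlib import Path

from careamics.file_io.read import get_read_func

read_func = get_read_func("tiff")  # resolves to read_tiff
# array = read_func(Path("image.tiff"))

# Only TIFF is registered in this diff; other types raise NotImplementedError.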
careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py}
@@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ValueError
         If the axes length is incorrect.
     """
-    if fnmatch(
+    if fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
         try:
             array = tifffile.imread(file_path)
         except (ValueError, OSError) as e:
careamics/file_io/write/get_func.py (new file)
@@ -0,0 +1,59 @@
+"""Module to get write functions."""
+
+from pathlib import Path
+from typing import Protocol, Union
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import write_tiff
+
+
+# This is very strict, arguments have to be called file_path & img
+# Alternative? - doesn't capture *args & **kwargs
+# WriteFunc = Callable[[Path, NDArray], None]
+class WriteFunc(Protocol):
+    """Protocol for type hinting write functions."""
+
+    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        img : numpy.ndarray
+            Image data to save.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
+    SupportedData.TIFF: write_tiff,
+}
+
+
+def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc:
+    """
+    Get the write function for the data type.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        Data type.
+
+    Returns
+    -------
+    callable
+        Write function.
+    """
+    if data_type in WRITE_FUNCS:
+        data_type = SupportedData(data_type)  # mypy complaining about dict key type
+        return WRITE_FUNCS[data_type]
+    else:
+        raise NotImplementedError(f"Data type {data_type} is not supported.")
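Because `WriteFunc` is a `Protocol`, third-party writers type-check structurally as long as the argument names match; a hypothetical example (the `write_npy` helper below is not part of careamics):

from pathlib import Path

import numpy as np
from numpy.typing import NDArray

from careamics.file_io.write.get_func import WriteFunc


def write_npy(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    """Hypothetical .npy writer; arguments must be named file_path and img."""
    np.save(file_path, img)


writer: WriteFunc = write_npy  # accepted by a static type checker such as mypy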
careamics/file_io/write/tiff.py (new file)
@@ -0,0 +1,39 @@
+"""Write tiff function."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+
+import tifffile
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+
+def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
+    """
+    Write tiff files.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path
+        Path to file.
+    img : numpy.ndarray
+        Image data to save.
+    *args
+        Positional arguments passed to `tifffile.imwrite`.
+    **kwargs
+        Keyword arguments passed to `tifffile.imwrite`.
+
+    Raises
+    ------
+    ValueError
+        When the file extension of `file_path` does not match the Unix shell-style
+        pattern '*.tif*'.
+    """
+    if not fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
+        raise ValueError(
+            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
+        )
+    tifffile.imwrite(file_path, img, *args, **kwargs)
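Usage sketch for the new writer; the extension guard rejects anything not matching '*.tif*' before delegating to `tifffile.imwrite`:

from pathlib import Path

import numpy as np

from careamics.file_io.write.tiff import write_tiff

img = np.zeros((64, 64), dtype=np.float32)
write_tiff(Path("prediction.tiff"), img)  # ok, suffix matches '*.tif*'
# write_tiff(Path("prediction.png"), img)  # would raise ValueError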
careamics/lightning/__init__.py (new file)
@@ -0,0 +1,17 @@
+"""CAREamics PyTorch Lightning modules."""
+
+__all__ = [
+    "CAREamicsModule",
+    "create_careamics_module",
+    "TrainDataModule",
+    "create_train_datamodule",
+    "PredictDataModule",
+    "create_predict_datamodule",
+    "HyperParametersCallback",
+    "ProgressBarCallback",
+]
+
+from .callbacks import HyperParametersCallback, ProgressBarCallback
+from .lightning_module import CAREamicsModule, create_careamics_module
+from .predict_data_module import PredictDataModule, create_predict_datamodule
+from .train_data_module import TrainDataModule, create_train_datamodule
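The practical consequence for downstream code is new import locations; everything listed in `__all__` is importable from the subpackage:

# Old top-level modules (careamics.lightning_module, careamics.lightning_datamodule,
# careamics.lightning_prediction_datamodule) now live under careamics.lightning:
from careamics.lightning import (
    CAREamicsModule,
    PredictDataModule,
    TrainDataModule,
    create_careamics_module,
    create_predict_datamodule,
    create_train_datamodule,
)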
careamics/{lightning_module.py → lightning/lightning_module.py}
@@ -23,19 +23,19 @@ class CAREamicsModule(L.LightningModule):
     """
     CAREamics Lightning module.
 
-    This class encapsulates the
+    This class encapsulates the PyTorch model along with the training, validation,
     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
     Parameters
     ----------
-    algorithm_config :
+    algorithm_config : AlgorithmModel or dict
         Algorithm configuration.
 
     Attributes
     ----------
-    model : nn.Module
+    model : torch.nn.Module
         PyTorch model.
-    loss_func : nn.Module
+    loss_func : torch.nn.Module
         Loss function.
     optimizer_name : str
         Optimizer name.
@@ -53,7 +53,7 @@ class CAREamicsModule(L.LightningModule):
 
         Parameters
         ----------
-        algorithm_config :
+        algorithm_config : AlgorithmModel or dict
             Algorithm configuration.
         """
         super().__init__()
@@ -91,7 +91,7 @@ class CAREamicsModule(L.LightningModule):
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -114,7 +114,7 @@ class CAREamicsModule(L.LightningModule):
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -138,7 +138,7 @@ class CAREamicsModule(L.LightningModule):
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -202,101 +202,74 @@ class CAREamicsModule(L.LightningModule):
     }
 
 
-
-
+def create_careamics_module(
+    algorithm: Union[SupportedAlgorithm, str],
+    loss: Union[SupportedLoss, str],
+    architecture: Union[SupportedArchitecture, str],
+    model_parameters: Optional[dict] = None,
+    optimizer: Union[SupportedOptimizer, str] = "Adam",
+    optimizer_parameters: Optional[dict] = None,
+    lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+    lr_scheduler_parameters: Optional[dict] = None,
+) -> CAREamicsModule:
+    """Create a CAREamics Lithgning module.
 
-    This
-    parameters validation.
+    This function exposes parameters used to create an AlgorithmModel instance,
+    triggering parameters validation.
 
     Parameters
     ----------
-    algorithm :
+    algorithm : SupportedAlgorithm or str
         Algorithm to use for training (see SupportedAlgorithm).
-    loss :
+    loss : SupportedLoss or str
         Loss function to use for training (see SupportedLoss).
-    architecture :
+    architecture : SupportedArchitecture or str
         Model architecture to use for training (see SupportedArchitecture).
     model_parameters : dict, optional
         Model parameters to use for training, by default {}. Model parameters are
         defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
         `careamics.config.architectures`).
-    optimizer :
+    optimizer : SupportedOptimizer or str, optional
         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
     optimizer_parameters : dict, optional
         Optimizer parameters to use for training, as defined in `torch.optim`, by
         default {}.
-    lr_scheduler :
+    lr_scheduler : SupportedScheduler or str, optional
         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
         (see SupportedScheduler).
     lr_scheduler_parameters : dict, optional
         Learning rate scheduler parameters to use for training, as defined in
         `torch.optim`, by default {}.
-    """
-
-    def __init__(
-        self,
-        algorithm: Union[SupportedAlgorithm, str],
-        loss: Union[SupportedLoss, str],
-        architecture: Union[SupportedArchitecture, str],
-        model_parameters: Optional[dict] = None,
-        optimizer: Union[SupportedOptimizer, str] = "Adam",
-        optimizer_parameters: Optional[dict] = None,
-        lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
-        lr_scheduler_parameters: Optional[dict] = None,
-    ) -> None:
-        """
-        Wrapper for the CAREamics model, exposing all algorithm configuration arguments.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            "algorithm": algorithm,
-            "loss": loss,
-            "optimizer": {
-                "name": optimizer,
-                "parameters": optimizer_parameters,
-            },
-            "lr_scheduler": {
-                "name": lr_scheduler,
-                "parameters": lr_scheduler_parameters,
-            },
-        }
-        model_configuration = {"architecture": architecture}
-        model_configuration.update(model_parameters)
-
-        # add model parameters to algorithm configuration
-        algorithm_configuration["model"] = model_configuration
-
-        # call the parent init using an AlgorithmModel instance
-        super().__init__(AlgorithmConfig(**algorithm_configuration))
-
-        # TODO add load_from_checkpoint wrapper
+    Returns
+    -------
+    CAREamicsModule
+        CAREamics Lightning module.
+    """
+    # create a AlgorithmModel compatible dictionary
+    if lr_scheduler_parameters is None:
+        lr_scheduler_parameters = {}
+    if optimizer_parameters is None:
+        optimizer_parameters = {}
+    if model_parameters is None:
+        model_parameters = {}
+    algorithm_configuration = {
+        "algorithm": algorithm,
+        "loss": loss,
+        "optimizer": {
+            "name": optimizer,
+            "parameters": optimizer_parameters,
+        },
+        "lr_scheduler": {
+            "name": lr_scheduler,
+            "parameters": lr_scheduler_parameters,
+        },
+    }
+    model_configuration = {"architecture": architecture}
+    model_configuration.update(model_parameters)
+
+    # add model parameters to algorithm configuration
+    algorithm_configuration["model"] = model_configuration
+
+    # call the parent init using an AlgorithmModel instance
+    return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
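Where the removed wrapper subclass was instantiated directly, callers now go through the factory function instead; a sketch with illustrative argument values:

from careamics.lightning import create_careamics_module

# Validation happens inside AlgorithmConfig when the module is created.
module = create_careamics_module(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    optimizer="Adam",
    optimizer_parameters={"lr": 1e-4},
)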