careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (54) hide show
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +83 -62
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -0
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +2 -0
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +1 -79
  11. careamics/config/configuration_model.py +12 -7
  12. careamics/config/data_model.py +29 -10
  13. careamics/config/inference_model.py +12 -2
  14. careamics/config/optimizer_models.py +6 -0
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/tile_information.py +10 -0
  17. careamics/config/training_model.py +5 -1
  18. careamics/dataset/dataset_utils/__init__.py +0 -6
  19. careamics/dataset/dataset_utils/file_utils.py +1 -1
  20. careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
  21. careamics/dataset/in_memory_dataset.py +37 -21
  22. careamics/dataset/iterable_dataset.py +38 -34
  23. careamics/dataset/iterable_pred_dataset.py +2 -1
  24. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  25. careamics/dataset/patching/patching.py +53 -37
  26. careamics/file_io/__init__.py +7 -0
  27. careamics/file_io/read/__init__.py +11 -0
  28. careamics/file_io/read/get_func.py +56 -0
  29. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
  30. careamics/file_io/write/__init__.py +9 -0
  31. careamics/file_io/write/get_func.py +59 -0
  32. careamics/file_io/write/tiff.py +39 -0
  33. careamics/lightning/__init__.py +17 -0
  34. careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
  35. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
  36. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
  37. careamics/model_io/bmz_io.py +1 -1
  38. careamics/model_io/model_io_utils.py +1 -1
  39. careamics/prediction_utils/__init__.py +0 -2
  40. careamics/prediction_utils/prediction_outputs.py +18 -46
  41. careamics/prediction_utils/stitch_prediction.py +17 -14
  42. careamics/utils/__init__.py +2 -0
  43. careamics/utils/autocorrelation.py +40 -0
  44. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
  45. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
  46. careamics/config/configuration_example.py +0 -86
  47. careamics/dataset/dataset_utils/read_utils.py +0 -27
  48. careamics/prediction_utils/create_pred_datamodule.py +0 -185
  49. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  50. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  51. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  52. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  53. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
  54. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -2,9 +2,10 @@
2
2
 
3
3
  from dataclasses import dataclass
4
4
  from pathlib import Path
5
- from typing import Callable, List, Tuple, Union
5
+ from typing import Callable, Union
6
6
 
7
7
  import numpy as np
8
+ from numpy.typing import NDArray
8
9
 
9
10
  from ...utils.logging import get_logger
10
11
  from ..dataset_utils import reshape_array
@@ -18,34 +19,49 @@ logger = get_logger(__name__)
18
19
  class Stats:
19
20
  """Dataclass to store statistics."""
20
21
 
21
- means: Union[np.ndarray, tuple, list, None]
22
- stds: Union[np.ndarray, tuple, list, None]
22
+ means: Union[NDArray, tuple, list, None]
23
+ """Mean of the data across channels."""
23
24
 
25
+ stds: Union[NDArray, tuple, list, None]
26
+ """Standard deviation of the data across channels."""
24
27
 
25
- @dataclass
26
- class PatchedOutput:
27
- """Dataclass to store patches and statistics."""
28
+ def get_statistics(self) -> tuple[list[float], list[float]]:
29
+ """Return the means and standard deviations.
28
30
 
29
- patches: Union[np.ndarray]
30
- targets: Union[np.ndarray, None]
31
- image_stats: Stats
32
- target_stats: Stats
31
+ Returns
32
+ -------
33
+ tuple of two lists of floats
34
+ Means and standard deviations.
35
+ """
36
+ if self.means is None or self.stds is None:
37
+ return [], []
38
+
39
+ return list(self.means), list(self.stds)
33
40
 
34
41
 
35
42
  @dataclass
36
- class StatsOutput:
43
+ class PatchedOutput:
37
44
  """Dataclass to store patches and statistics."""
38
45
 
46
+ patches: Union[NDArray]
47
+ """Image patches."""
48
+
49
+ targets: Union[NDArray, None]
50
+ """Target patches."""
51
+
39
52
  image_stats: Stats
53
+ """Statistics of the image patches."""
54
+
40
55
  target_stats: Stats
56
+ """Statistics of the target patches."""
41
57
 
42
58
 
43
59
  # called by in memory dataset
44
60
  def prepare_patches_supervised(
45
- train_files: List[Path],
46
- target_files: List[Path],
61
+ train_files: list[Path],
62
+ target_files: list[Path],
47
63
  axes: str,
48
- patch_size: Union[List[int], Tuple[int, ...]],
64
+ patch_size: Union[list[int], tuple[int, ...]],
49
65
  read_source_func: Callable,
50
66
  ) -> PatchedOutput:
51
67
  """
@@ -55,13 +71,13 @@ def prepare_patches_supervised(
55
71
 
56
72
  Parameters
57
73
  ----------
58
- train_files : List[Path]
74
+ train_files : list of pathlib.Path
59
75
  List of paths to training data.
60
- target_files : List[Path]
76
+ target_files : list of pathlib.Path
61
77
  List of paths to target data.
62
78
  axes : str
63
79
  Axes of the data.
64
- patch_size : Union[List[int], Tuple[int]]
80
+ patch_size : list or tuple of int
65
81
  Size of the patches.
66
82
  read_source_func : Callable
67
83
  Function to read the data.
@@ -127,9 +143,9 @@ def prepare_patches_supervised(
127
143
 
128
144
  # called by in_memory_dataset
129
145
  def prepare_patches_unsupervised(
130
- train_files: List[Path],
146
+ train_files: list[Path],
131
147
  axes: str,
132
- patch_size: Union[List[int], Tuple[int]],
148
+ patch_size: Union[list[int], tuple[int]],
133
149
  read_source_func: Callable,
134
150
  ) -> PatchedOutput:
135
151
  """Iterate over data source and create an array of patches.
@@ -138,19 +154,19 @@ def prepare_patches_unsupervised(
138
154
 
139
155
  Parameters
140
156
  ----------
141
- train_files : List[Path]
157
+ train_files : list of pathlib.Path
142
158
  List of paths to training data.
143
159
  axes : str
144
160
  Axes of the data.
145
- patch_size : Union[List[int], Tuple[int]]
161
+ patch_size : list or tuple of int
146
162
  Size of the patches.
147
163
  read_source_func : Callable
148
164
  Function to read the data.
149
165
 
150
166
  Returns
151
167
  -------
152
- Tuple[np.ndarray, None, float, float]
153
- Source and target patches, mean and standard deviation.
168
+ PatchedOutput
169
+ Dataclass holding patches and their statistics.
154
170
  """
155
171
  means, stds, num_samples = 0, 0, 0
156
172
  all_patches = []
@@ -189,10 +205,10 @@ def prepare_patches_unsupervised(
189
205
 
190
206
  # called on arrays by in memory dataset
191
207
  def prepare_patches_supervised_array(
192
- data: np.ndarray,
208
+ data: NDArray,
193
209
  axes: str,
194
- data_target: np.ndarray,
195
- patch_size: Union[List[int], Tuple[int]],
210
+ data_target: NDArray,
211
+ patch_size: Union[list[int], tuple[int]],
196
212
  ) -> PatchedOutput:
197
213
  """Iterate over data source and create an array of patches.
198
214
 
@@ -203,19 +219,19 @@ def prepare_patches_supervised_array(
203
219
 
204
220
  Parameters
205
221
  ----------
206
- data : np.ndarray
222
+ data : numpy.ndarray
207
223
  Input data array.
208
224
  axes : str
209
225
  Axes of the data.
210
- data_target : np.ndarray
226
+ data_target : numpy.ndarray
211
227
  Target data array.
212
- patch_size : Union[List[int], Tuple[int]]
228
+ patch_size : list or tuple of int
213
229
  Size of the patches.
214
230
 
215
231
  Returns
216
232
  -------
217
- Tuple[np.ndarray, np.ndarray, float, float]
218
- Source and target patches, mean and standard deviation.
233
+ PatchedOutput
234
+ Dataclass holding the source and target patches, with their statistics.
219
235
  """
220
236
  # reshape array
221
237
  reshaped_sample = reshape_array(data, axes)
@@ -245,9 +261,9 @@ def prepare_patches_supervised_array(
245
261
 
246
262
  # called by in memory dataset
247
263
  def prepare_patches_unsupervised_array(
248
- data: np.ndarray,
264
+ data: NDArray,
249
265
  axes: str,
250
- patch_size: Union[List[int], Tuple[int]],
266
+ patch_size: Union[list[int], tuple[int]],
251
267
  ) -> PatchedOutput:
252
268
  """
253
269
  Iterate over data source and create an array of patches.
@@ -259,17 +275,17 @@ def prepare_patches_unsupervised_array(
259
275
 
260
276
  Parameters
261
277
  ----------
262
- data : np.ndarray
278
+ data : numpy.ndarray
263
279
  Input data array.
264
280
  axes : str
265
281
  Axes of the data.
266
- patch_size : Union[List[int], Tuple[int]]
282
+ patch_size : list or tuple of int
267
283
  Size of the patches.
268
284
 
269
285
  Returns
270
286
  -------
271
- Tuple[np.ndarray, None, float, float]
272
- Source patches, mean and standard deviation.
287
+ PatchedOutput
288
+ Dataclass holding the patches and their statistics.
273
289
  """
274
290
  # reshape array
275
291
  reshaped_sample = reshape_array(data, axes)
@@ -0,0 +1,7 @@
1
+ """Functions relating to reading and writing image files."""
2
+
3
+ __all__ = ["read", "write", "get_read_func", "get_write_func"]
4
+
5
+ from . import read, write
6
+ from .read import get_read_func
7
+ from .write import get_write_func
@@ -0,0 +1,11 @@
1
+ """Functions relating to reading image files of different formats."""
2
+
3
+ __all__ = [
4
+ "get_read_func",
5
+ "read_tiff",
6
+ "read_zarr",
7
+ ]
8
+
9
+ from .get_func import get_read_func
10
+ from .tiff import read_tiff
11
+ from .zarr import read_zarr
@@ -0,0 +1,56 @@
1
+ """Module to get read functions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Callable, Dict, Protocol, Union
5
+
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.config.support import SupportedData
9
+
10
+ from .tiff import read_tiff
11
+
12
+
13
+ # This is very strict, function signature has to match including arg names
14
+ # See WriteFunc notes
15
+ class ReadFunc(Protocol):
16
+ """Protocol for type hinting read functions."""
17
+
18
+ def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
19
+ """
20
+ Type hinted callables must match this function signature (not including self).
21
+
22
+ Parameters
23
+ ----------
24
+ file_path : pathlib.Path
25
+ Path to file.
26
+ *args
27
+ Other positional arguments.
28
+ **kwargs
29
+ Other keyword arguments.
30
+ """
31
+
32
+
33
+ READ_FUNCS: Dict[SupportedData, ReadFunc] = {
34
+ SupportedData.TIFF: read_tiff,
35
+ }
36
+
37
+
38
+ def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
39
+ """
40
+ Get the read function for the data type.
41
+
42
+ Parameters
43
+ ----------
44
+ data_type : str or SupportedData
45
+ Data type.
46
+
47
+ Returns
48
+ -------
49
+ callable
50
+ Read function.
51
+ """
52
+ if data_type in READ_FUNCS:
53
+ data_type = SupportedData(data_type) # mypy complaining about dict key type
54
+ return READ_FUNCS[data_type]
55
+ else:
56
+ raise NotImplementedError(f"Data type '{data_type}' is not supported.")
@@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
44
44
  ValueError
45
45
  If the axes length is incorrect.
46
46
  """
47
- if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)):
47
+ if fnmatch(
48
+ file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
49
+ ):
48
50
  try:
49
51
  array = tifffile.imread(file_path)
50
52
  except (ValueError, OSError) as e:
@@ -0,0 +1,9 @@
1
+ """Functions relating to writing image files of different formats."""
2
+
3
+ __all__ = [
4
+ "get_write_func",
5
+ "write_tiff",
6
+ ]
7
+
8
+ from .get_func import get_write_func
9
+ from .tiff import write_tiff
@@ -0,0 +1,59 @@
1
+ """Module to get write functions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Protocol, Union
5
+
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.config.support import SupportedData
9
+
10
+ from .tiff import write_tiff
11
+
12
+
13
+ # This is very strict, arguments have to be called file_path & img
14
+ # Alternative? - doesn't capture *args & **kwargs
15
+ # WriteFunc = Callable[[Path, NDArray], None]
16
+ class WriteFunc(Protocol):
17
+ """Protocol for type hinting write functions."""
18
+
19
+ def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
20
+ """
21
+ Type hinted callables must match this function signature (not including self).
22
+
23
+ Parameters
24
+ ----------
25
+ file_path : pathlib.Path
26
+ Path to file.
27
+ img : numpy.ndarray
28
+ Image data to save.
29
+ *args
30
+ Other positional arguments.
31
+ **kwargs
32
+ Other keyword arguments.
33
+ """
34
+
35
+
36
+ WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
37
+ SupportedData.TIFF: write_tiff,
38
+ }
39
+
40
+
41
+ def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc:
42
+ """
43
+ Get the write function for the data type.
44
+
45
+ Parameters
46
+ ----------
47
+ data_type : str or SupportedData
48
+ Data type.
49
+
50
+ Returns
51
+ -------
52
+ callable
53
+ Write function.
54
+ """
55
+ if data_type in WRITE_FUNCS:
56
+ data_type = SupportedData(data_type) # mypy complaining about dict key type
57
+ return WRITE_FUNCS[data_type]
58
+ else:
59
+ raise NotImplementedError(f"Data type {data_type} is not supported.")
@@ -0,0 +1,39 @@
1
+ """Write tiff function."""
2
+
3
+ from fnmatch import fnmatch
4
+ from pathlib import Path
5
+
6
+ import tifffile
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config.support import SupportedData
10
+
11
+
12
+ def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
13
+ """
14
+ Write tiff files.
15
+
16
+ Parameters
17
+ ----------
18
+ file_path : pathlib.Path
19
+ Path to file.
20
+ img : numpy.ndarray
21
+ Image data to save.
22
+ *args
23
+ Positional arguments passed to `tifffile.imwrite`.
24
+ **kwargs
25
+ Keyword arguments passed to `tifffile.imwrite`.
26
+
27
+ Raises
28
+ ------
29
+ ValueError
30
+ When the file extension of `file_path` does not match the Unix shell-style
31
+ pattern '*.tif*'.
32
+ """
33
+ if not fnmatch(
34
+ file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
35
+ ):
36
+ raise ValueError(
37
+ f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
38
+ )
39
+ tifffile.imwrite(file_path, img, *args, **kwargs)
@@ -0,0 +1,17 @@
1
+ """CAREamics PyTorch Lightning modules."""
2
+
3
+ __all__ = [
4
+ "CAREamicsModule",
5
+ "create_careamics_module",
6
+ "TrainDataModule",
7
+ "create_train_datamodule",
8
+ "PredictDataModule",
9
+ "create_predict_datamodule",
10
+ "HyperParametersCallback",
11
+ "ProgressBarCallback",
12
+ ]
13
+
14
+ from .callbacks import HyperParametersCallback, ProgressBarCallback
15
+ from .lightning_module import CAREamicsModule, create_careamics_module
16
+ from .predict_data_module import PredictDataModule, create_predict_datamodule
17
+ from .train_data_module import TrainDataModule, create_train_datamodule
@@ -23,19 +23,19 @@ class CAREamicsModule(L.LightningModule):
23
23
  """
24
24
  CAREamics Lightning module.
25
25
 
26
- This class encapsulates the a PyTorch model along with the training, validation,
26
+ This class encapsulates the PyTorch model along with the training, validation,
27
27
  and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
28
28
 
29
29
  Parameters
30
30
  ----------
31
- algorithm_config : Union[AlgorithmModel, dict]
31
+ algorithm_config : AlgorithmModel or dict
32
32
  Algorithm configuration.
33
33
 
34
34
  Attributes
35
35
  ----------
36
- model : nn.Module
36
+ model : torch.nn.Module
37
37
  PyTorch model.
38
- loss_func : nn.Module
38
+ loss_func : torch.nn.Module
39
39
  Loss function.
40
40
  optimizer_name : str
41
41
  Optimizer name.
@@ -53,7 +53,7 @@ class CAREamicsModule(L.LightningModule):
53
53
 
54
54
  Parameters
55
55
  ----------
56
- algorithm_config : Union[AlgorithmModel, dict]
56
+ algorithm_config : AlgorithmModel or dict
57
57
  Algorithm configuration.
58
58
  """
59
59
  super().__init__()
@@ -91,7 +91,7 @@ class CAREamicsModule(L.LightningModule):
91
91
 
92
92
  Parameters
93
93
  ----------
94
- batch : Tensor
94
+ batch : torch.Tensor
95
95
  Input batch.
96
96
  batch_idx : Any
97
97
  Batch index.
@@ -114,7 +114,7 @@ class CAREamicsModule(L.LightningModule):
114
114
 
115
115
  Parameters
116
116
  ----------
117
- batch : Tensor
117
+ batch : torch.Tensor
118
118
  Input batch.
119
119
  batch_idx : Any
120
120
  Batch index.
@@ -138,7 +138,7 @@ class CAREamicsModule(L.LightningModule):
138
138
 
139
139
  Parameters
140
140
  ----------
141
- batch : Tensor
141
+ batch : torch.Tensor
142
142
  Input batch.
143
143
  batch_idx : Any
144
144
  Batch index.
@@ -202,101 +202,74 @@ class CAREamicsModule(L.LightningModule):
202
202
  }
203
203
 
204
204
 
205
- class CAREamicsModuleWrapper(CAREamicsModule):
206
- """Class defining the API for CAREamics Lightning layer.
205
+ def create_careamics_module(
206
+ algorithm: Union[SupportedAlgorithm, str],
207
+ loss: Union[SupportedLoss, str],
208
+ architecture: Union[SupportedArchitecture, str],
209
+ model_parameters: Optional[dict] = None,
210
+ optimizer: Union[SupportedOptimizer, str] = "Adam",
211
+ optimizer_parameters: Optional[dict] = None,
212
+ lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
213
+ lr_scheduler_parameters: Optional[dict] = None,
214
+ ) -> CAREamicsModule:
215
+ """Create a CAREamics Lightning module.
207
216
 
208
- This class exposes parameters used to create an AlgorithmModel instance, triggering
209
- parameters validation.
217
+ This function exposes parameters used to create an AlgorithmModel instance,
218
+ triggering parameters validation.
210
219
 
211
220
  Parameters
212
221
  ----------
213
- algorithm : Union[SupportedAlgorithm, str]
222
+ algorithm : SupportedAlgorithm or str
214
223
  Algorithm to use for training (see SupportedAlgorithm).
215
- loss : Union[SupportedLoss, str]
224
+ loss : SupportedLoss or str
216
225
  Loss function to use for training (see SupportedLoss).
217
- architecture : Union[SupportedArchitecture, str]
226
+ architecture : SupportedArchitecture or str
218
227
  Model architecture to use for training (see SupportedArchitecture).
219
228
  model_parameters : dict, optional
220
229
  Model parameters to use for training, by default {}. Model parameters are
221
230
  defined in the relevant `torch.nn.Module` class, or Pydantic model (see
222
231
  `careamics.config.architectures`).
223
- optimizer : Union[SupportedOptimizer, str], optional
232
+ optimizer : SupportedOptimizer or str, optional
224
233
  Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
225
234
  optimizer_parameters : dict, optional
226
235
  Optimizer parameters to use for training, as defined in `torch.optim`, by
227
236
  default {}.
228
- lr_scheduler : Union[SupportedScheduler, str], optional
237
+ lr_scheduler : SupportedScheduler or str, optional
229
238
  Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
230
239
  (see SupportedScheduler).
231
240
  lr_scheduler_parameters : dict, optional
232
241
  Learning rate scheduler parameters to use for training, as defined in
233
242
  `torch.optim`, by default {}.
234
- """
235
-
236
- def __init__(
237
- self,
238
- algorithm: Union[SupportedAlgorithm, str],
239
- loss: Union[SupportedLoss, str],
240
- architecture: Union[SupportedArchitecture, str],
241
- model_parameters: Optional[dict] = None,
242
- optimizer: Union[SupportedOptimizer, str] = "Adam",
243
- optimizer_parameters: Optional[dict] = None,
244
- lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
245
- lr_scheduler_parameters: Optional[dict] = None,
246
- ) -> None:
247
- """
248
- Wrapper for the CAREamics model, exposing all algorithm configuration arguments.
249
243
 
250
- Parameters
251
- ----------
252
- algorithm : Union[SupportedAlgorithm, str]
253
- Algorithm to use for training (see SupportedAlgorithm).
254
- loss : Union[SupportedLoss, str]
255
- Loss function to use for training (see SupportedLoss).
256
- architecture : Union[SupportedArchitecture, str]
257
- Model architecture to use for training (see SupportedArchitecture).
258
- model_parameters : dict, optional
259
- Model parameters to use for training, by default {}. Model parameters are
260
- defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
261
- `careamics.config.architectures`).
262
- optimizer : Union[SupportedOptimizer, str], optional
263
- Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
264
- optimizer_parameters : dict, optional
265
- Optimizer parameters to use for training, as defined in `torch.optim`, by
266
- default {}.
267
- lr_scheduler : Union[SupportedScheduler, str], optional
268
- Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
269
- (see SupportedScheduler).
270
- lr_scheduler_parameters : dict, optional
271
- Learning rate scheduler parameters to use for training, as defined in
272
- `torch.optim`, by default {}.
273
- """
274
- # create a AlgorithmModel compatible dictionary
275
- if lr_scheduler_parameters is None:
276
- lr_scheduler_parameters = {}
277
- if optimizer_parameters is None:
278
- optimizer_parameters = {}
279
- if model_parameters is None:
280
- model_parameters = {}
281
- algorithm_configuration = {
282
- "algorithm": algorithm,
283
- "loss": loss,
284
- "optimizer": {
285
- "name": optimizer,
286
- "parameters": optimizer_parameters,
287
- },
288
- "lr_scheduler": {
289
- "name": lr_scheduler,
290
- "parameters": lr_scheduler_parameters,
291
- },
292
- }
293
- model_configuration = {"architecture": architecture}
294
- model_configuration.update(model_parameters)
295
-
296
- # add model parameters to algorithm configuration
297
- algorithm_configuration["model"] = model_configuration
298
-
299
- # call the parent init using an AlgorithmModel instance
300
- super().__init__(AlgorithmConfig(**algorithm_configuration))
301
-
302
- # TODO add load_from_checkpoint wrapper
244
+ Returns
245
+ -------
246
+ CAREamicsModule
247
+ CAREamics Lightning module.
248
+ """
249
+ # create an AlgorithmModel-compatible dictionary
250
+ if lr_scheduler_parameters is None:
251
+ lr_scheduler_parameters = {}
252
+ if optimizer_parameters is None:
253
+ optimizer_parameters = {}
254
+ if model_parameters is None:
255
+ model_parameters = {}
256
+ algorithm_configuration = {
257
+ "algorithm": algorithm,
258
+ "loss": loss,
259
+ "optimizer": {
260
+ "name": optimizer,
261
+ "parameters": optimizer_parameters,
262
+ },
263
+ "lr_scheduler": {
264
+ "name": lr_scheduler,
265
+ "parameters": lr_scheduler_parameters,
266
+ },
267
+ }
268
+ model_configuration = {"architecture": architecture}
269
+ model_configuration.update(model_parameters)
270
+
271
+ # add model parameters to algorithm configuration
272
+ algorithm_configuration["model"] = model_configuration
273
+
274
+ # build an AlgorithmConfig from the dictionary and wrap it in a CAREamicsModule
275
+ return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))