careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
|
|
|
44
44
|
ValueError
|
|
45
45
|
If the axes length is incorrect.
|
|
46
46
|
"""
|
|
47
|
-
if fnmatch(
|
|
47
|
+
if fnmatch(
|
|
48
|
+
file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
|
|
49
|
+
):
|
|
48
50
|
try:
|
|
49
51
|
array = tifffile.imread(file_path)
|
|
50
52
|
except (ValueError, OSError) as e:
|
|
@@ -53,13 +55,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
|
|
|
53
55
|
else:
|
|
54
56
|
raise ValueError(f"File {file_path} is not a valid tiff.")
|
|
55
57
|
|
|
56
|
-
# check dimensions
|
|
57
|
-
# TODO or should this really be done here? probably in the LightningDataModule
|
|
58
|
-
# TODO this should also be centralized somewhere else (validate_dimensions)
|
|
59
|
-
if len(array.shape) < 2 or len(array.shape) > 6:
|
|
60
|
-
raise ValueError(
|
|
61
|
-
f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
|
|
62
|
-
f"file {file_path})."
|
|
63
|
-
)
|
|
64
|
-
|
|
65
58
|
return array
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Module to get write functions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Protocol, Union
|
|
5
|
+
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedData
|
|
9
|
+
|
|
10
|
+
from .tiff import write_tiff
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# This is very strict, arguments have to be called file_path & img
|
|
14
|
+
# Alternative? - doesn't capture *args & **kwargs
|
|
15
|
+
# WriteFunc = Callable[[Path, NDArray], None]
|
|
16
|
+
class WriteFunc(Protocol):
|
|
17
|
+
"""Protocol for type hinting write functions."""
|
|
18
|
+
|
|
19
|
+
def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
|
|
20
|
+
"""
|
|
21
|
+
Type hinted callables must match this function signature (not including self).
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
file_path : pathlib.Path
|
|
26
|
+
Path to file.
|
|
27
|
+
img : numpy.ndarray
|
|
28
|
+
Image data to save.
|
|
29
|
+
*args
|
|
30
|
+
Other positional arguments.
|
|
31
|
+
**kwargs
|
|
32
|
+
Other keyword arguments.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
|
|
37
|
+
SupportedData.TIFF: write_tiff,
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc:
|
|
42
|
+
"""
|
|
43
|
+
Get the write function for the data type.
|
|
44
|
+
|
|
45
|
+
Parameters
|
|
46
|
+
----------
|
|
47
|
+
data_type : SupportedData
|
|
48
|
+
Data type.
|
|
49
|
+
|
|
50
|
+
Returns
|
|
51
|
+
-------
|
|
52
|
+
callable
|
|
53
|
+
Write function.
|
|
54
|
+
"""
|
|
55
|
+
if data_type in WRITE_FUNCS:
|
|
56
|
+
data_type = SupportedData(data_type) # mypy complaining about dict key type
|
|
57
|
+
return WRITE_FUNCS[data_type]
|
|
58
|
+
else:
|
|
59
|
+
raise NotImplementedError(f"Data type {data_type} is not supported.")
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Write tiff function."""
|
|
2
|
+
|
|
3
|
+
from fnmatch import fnmatch
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import tifffile
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.support import SupportedData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
|
|
13
|
+
"""
|
|
14
|
+
Write tiff files.
|
|
15
|
+
|
|
16
|
+
Parameters
|
|
17
|
+
----------
|
|
18
|
+
file_path : pathlib.Path
|
|
19
|
+
Path to file.
|
|
20
|
+
img : numpy.ndarray
|
|
21
|
+
Image data to save.
|
|
22
|
+
*args
|
|
23
|
+
Positional arguments passed to `tifffile.imwrite`.
|
|
24
|
+
**kwargs
|
|
25
|
+
Keyword arguments passed to `tifffile.imwrite`.
|
|
26
|
+
|
|
27
|
+
Raises
|
|
28
|
+
------
|
|
29
|
+
ValueError
|
|
30
|
+
When the file extension of `file_path` does not match the Unix shell-style
|
|
31
|
+
pattern '*.tif*'.
|
|
32
|
+
"""
|
|
33
|
+
if not fnmatch(
|
|
34
|
+
file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
|
|
35
|
+
):
|
|
36
|
+
raise ValueError(
|
|
37
|
+
f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
|
|
38
|
+
)
|
|
39
|
+
tifffile.imwrite(file_path, img, *args, **kwargs)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""CAREamics PyTorch Lightning modules."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CAREamicsModule",
|
|
5
|
+
"create_careamics_module",
|
|
6
|
+
"TrainDataModule",
|
|
7
|
+
"create_train_datamodule",
|
|
8
|
+
"PredictDataModule",
|
|
9
|
+
"create_predict_datamodule",
|
|
10
|
+
"HyperParametersCallback",
|
|
11
|
+
"ProgressBarCallback",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
15
|
+
from .lightning_module import CAREamicsModule, create_careamics_module
|
|
16
|
+
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
17
|
+
from .train_data_module import TrainDataModule, create_train_datamodule
|
|
@@ -23,19 +23,19 @@ class CAREamicsModule(L.LightningModule):
|
|
|
23
23
|
"""
|
|
24
24
|
CAREamics Lightning module.
|
|
25
25
|
|
|
26
|
-
This class encapsulates the
|
|
26
|
+
This class encapsulates the PyTorch model along with the training, validation,
|
|
27
27
|
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
28
28
|
|
|
29
29
|
Parameters
|
|
30
30
|
----------
|
|
31
|
-
algorithm_config :
|
|
31
|
+
algorithm_config : AlgorithmModel or dict
|
|
32
32
|
Algorithm configuration.
|
|
33
33
|
|
|
34
34
|
Attributes
|
|
35
35
|
----------
|
|
36
|
-
model : nn.Module
|
|
36
|
+
model : torch.nn.Module
|
|
37
37
|
PyTorch model.
|
|
38
|
-
loss_func : nn.Module
|
|
38
|
+
loss_func : torch.nn.Module
|
|
39
39
|
Loss function.
|
|
40
40
|
optimizer_name : str
|
|
41
41
|
Optimizer name.
|
|
@@ -53,7 +53,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
53
53
|
|
|
54
54
|
Parameters
|
|
55
55
|
----------
|
|
56
|
-
algorithm_config :
|
|
56
|
+
algorithm_config : AlgorithmModel or dict
|
|
57
57
|
Algorithm configuration.
|
|
58
58
|
"""
|
|
59
59
|
super().__init__()
|
|
@@ -91,7 +91,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
91
91
|
|
|
92
92
|
Parameters
|
|
93
93
|
----------
|
|
94
|
-
batch : Tensor
|
|
94
|
+
batch : torch.Tensor
|
|
95
95
|
Input batch.
|
|
96
96
|
batch_idx : Any
|
|
97
97
|
Batch index.
|
|
@@ -114,7 +114,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
114
114
|
|
|
115
115
|
Parameters
|
|
116
116
|
----------
|
|
117
|
-
batch : Tensor
|
|
117
|
+
batch : torch.Tensor
|
|
118
118
|
Input batch.
|
|
119
119
|
batch_idx : Any
|
|
120
120
|
Batch index.
|
|
@@ -138,7 +138,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
138
138
|
|
|
139
139
|
Parameters
|
|
140
140
|
----------
|
|
141
|
-
batch : Tensor
|
|
141
|
+
batch : torch.Tensor
|
|
142
142
|
Input batch.
|
|
143
143
|
batch_idx : Any
|
|
144
144
|
Batch index.
|
|
@@ -148,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
|
|
|
148
148
|
Any
|
|
149
149
|
Model output.
|
|
150
150
|
"""
|
|
151
|
-
|
|
151
|
+
if self._trainer.datamodule.tiled:
|
|
152
|
+
x, *aux = batch
|
|
153
|
+
else:
|
|
154
|
+
x = batch
|
|
155
|
+
aux = []
|
|
152
156
|
|
|
153
157
|
# apply test-time augmentation if available
|
|
154
158
|
# TODO: probably wont work with batch size > 1
|
|
155
159
|
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
156
160
|
tta = ImageRestorationTTA()
|
|
157
|
-
augmented_batch = tta.forward(
|
|
161
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
158
162
|
augmented_output = []
|
|
159
163
|
for augmented in augmented_batch:
|
|
160
164
|
augmented_pred = self.model(augmented)
|
|
@@ -165,13 +169,13 @@ class CAREamicsModule(L.LightningModule):
|
|
|
165
169
|
|
|
166
170
|
# Denormalize the output
|
|
167
171
|
denorm = Denormalize(
|
|
168
|
-
|
|
169
|
-
|
|
172
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
173
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
170
174
|
)
|
|
171
|
-
denormalized_output
|
|
175
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
172
176
|
|
|
173
|
-
if len(aux) > 0:
|
|
174
|
-
return denormalized_output, aux
|
|
177
|
+
if len(aux) > 0: # aux can be tiling information
|
|
178
|
+
return denormalized_output, *aux
|
|
175
179
|
else:
|
|
176
180
|
return denormalized_output
|
|
177
181
|
|
|
@@ -198,101 +202,74 @@ class CAREamicsModule(L.LightningModule):
|
|
|
198
202
|
}
|
|
199
203
|
|
|
200
204
|
|
|
201
|
-
|
|
202
|
-
|
|
205
|
+
def create_careamics_module(
|
|
206
|
+
algorithm: Union[SupportedAlgorithm, str],
|
|
207
|
+
loss: Union[SupportedLoss, str],
|
|
208
|
+
architecture: Union[SupportedArchitecture, str],
|
|
209
|
+
model_parameters: Optional[dict] = None,
|
|
210
|
+
optimizer: Union[SupportedOptimizer, str] = "Adam",
|
|
211
|
+
optimizer_parameters: Optional[dict] = None,
|
|
212
|
+
lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
|
|
213
|
+
lr_scheduler_parameters: Optional[dict] = None,
|
|
214
|
+
) -> CAREamicsModule:
|
|
215
|
+
"""Create a CAREamics Lithgning module.
|
|
203
216
|
|
|
204
|
-
This
|
|
205
|
-
parameters validation.
|
|
217
|
+
This function exposes parameters used to create an AlgorithmModel instance,
|
|
218
|
+
triggering parameters validation.
|
|
206
219
|
|
|
207
220
|
Parameters
|
|
208
221
|
----------
|
|
209
|
-
algorithm :
|
|
222
|
+
algorithm : SupportedAlgorithm or str
|
|
210
223
|
Algorithm to use for training (see SupportedAlgorithm).
|
|
211
|
-
loss :
|
|
224
|
+
loss : SupportedLoss or str
|
|
212
225
|
Loss function to use for training (see SupportedLoss).
|
|
213
|
-
architecture :
|
|
226
|
+
architecture : SupportedArchitecture or str
|
|
214
227
|
Model architecture to use for training (see SupportedArchitecture).
|
|
215
228
|
model_parameters : dict, optional
|
|
216
229
|
Model parameters to use for training, by default {}. Model parameters are
|
|
217
230
|
defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
|
|
218
231
|
`careamics.config.architectures`).
|
|
219
|
-
optimizer :
|
|
232
|
+
optimizer : SupportedOptimizer or str, optional
|
|
220
233
|
Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
|
|
221
234
|
optimizer_parameters : dict, optional
|
|
222
235
|
Optimizer parameters to use for training, as defined in `torch.optim`, by
|
|
223
236
|
default {}.
|
|
224
|
-
lr_scheduler :
|
|
237
|
+
lr_scheduler : SupportedScheduler or str, optional
|
|
225
238
|
Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
|
|
226
239
|
(see SupportedScheduler).
|
|
227
240
|
lr_scheduler_parameters : dict, optional
|
|
228
241
|
Learning rate scheduler parameters to use for training, as defined in
|
|
229
242
|
`torch.optim`, by default {}.
|
|
230
|
-
"""
|
|
231
243
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
|
|
265
|
-
(see SupportedScheduler).
|
|
266
|
-
lr_scheduler_parameters : dict, optional
|
|
267
|
-
Learning rate scheduler parameters to use for training, as defined in
|
|
268
|
-
`torch.optim`, by default {}.
|
|
269
|
-
"""
|
|
270
|
-
# create a AlgorithmModel compatible dictionary
|
|
271
|
-
if lr_scheduler_parameters is None:
|
|
272
|
-
lr_scheduler_parameters = {}
|
|
273
|
-
if optimizer_parameters is None:
|
|
274
|
-
optimizer_parameters = {}
|
|
275
|
-
if model_parameters is None:
|
|
276
|
-
model_parameters = {}
|
|
277
|
-
algorithm_configuration = {
|
|
278
|
-
"algorithm": algorithm,
|
|
279
|
-
"loss": loss,
|
|
280
|
-
"optimizer": {
|
|
281
|
-
"name": optimizer,
|
|
282
|
-
"parameters": optimizer_parameters,
|
|
283
|
-
},
|
|
284
|
-
"lr_scheduler": {
|
|
285
|
-
"name": lr_scheduler,
|
|
286
|
-
"parameters": lr_scheduler_parameters,
|
|
287
|
-
},
|
|
288
|
-
}
|
|
289
|
-
model_configuration = {"architecture": architecture}
|
|
290
|
-
model_configuration.update(model_parameters)
|
|
291
|
-
|
|
292
|
-
# add model parameters to algorithm configuration
|
|
293
|
-
algorithm_configuration["model"] = model_configuration
|
|
294
|
-
|
|
295
|
-
# call the parent init using an AlgorithmModel instance
|
|
296
|
-
super().__init__(AlgorithmConfig(**algorithm_configuration))
|
|
297
|
-
|
|
298
|
-
# TODO add load_from_checkpoint wrapper
|
|
244
|
+
Returns
|
|
245
|
+
-------
|
|
246
|
+
CAREamicsModule
|
|
247
|
+
CAREamics Lightning module.
|
|
248
|
+
"""
|
|
249
|
+
# create a AlgorithmModel compatible dictionary
|
|
250
|
+
if lr_scheduler_parameters is None:
|
|
251
|
+
lr_scheduler_parameters = {}
|
|
252
|
+
if optimizer_parameters is None:
|
|
253
|
+
optimizer_parameters = {}
|
|
254
|
+
if model_parameters is None:
|
|
255
|
+
model_parameters = {}
|
|
256
|
+
algorithm_configuration = {
|
|
257
|
+
"algorithm": algorithm,
|
|
258
|
+
"loss": loss,
|
|
259
|
+
"optimizer": {
|
|
260
|
+
"name": optimizer,
|
|
261
|
+
"parameters": optimizer_parameters,
|
|
262
|
+
},
|
|
263
|
+
"lr_scheduler": {
|
|
264
|
+
"name": lr_scheduler,
|
|
265
|
+
"parameters": lr_scheduler_parameters,
|
|
266
|
+
},
|
|
267
|
+
}
|
|
268
|
+
model_configuration = {"architecture": architecture}
|
|
269
|
+
model_configuration.update(model_parameters)
|
|
270
|
+
|
|
271
|
+
# add model parameters to algorithm configuration
|
|
272
|
+
algorithm_configuration["model"] = model_configuration
|
|
273
|
+
|
|
274
|
+
# call the parent init using an AlgorithmModel instance
|
|
275
|
+
return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
|