careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +38 -0
- careamics/config/algorithms/n2n_algorithm_model.py +30 -0
- careamics/config/algorithms/n2v_algorithm_model.py +29 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +91 -186
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +5 -4
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +3 -3
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +36 -45
- careamics/losses/__init__.py +3 -3
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -4
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
- careamics-0.0.7.dist-info/RECORD +178 -0
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.5.dist-info/RECORD +0 -171
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,27 +1,120 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Literal, Optional, Union
|
|
4
|
-
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
3
|
+
from typing import Annotated, Any, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import Discriminator, Tag, TypeAdapter
|
|
6
|
+
|
|
7
|
+
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
8
|
+
from careamics.config.architectures import UNetModel
|
|
9
|
+
from careamics.config.care_configuration import CAREConfiguration
|
|
10
|
+
from careamics.config.configuration import Configuration
|
|
11
|
+
from careamics.config.data import DataConfig, N2VDataConfig
|
|
12
|
+
from careamics.config.n2n_configuration import N2NConfiguration
|
|
13
|
+
from careamics.config.n2v_configuration import N2VConfiguration
|
|
14
|
+
from careamics.config.support import (
|
|
15
|
+
SupportedAlgorithm,
|
|
10
16
|
SupportedArchitecture,
|
|
11
17
|
SupportedPixelManipulation,
|
|
12
18
|
SupportedTransform,
|
|
13
19
|
)
|
|
14
|
-
from .training_model import TrainingConfig
|
|
15
|
-
from .transformations import (
|
|
20
|
+
from careamics.config.training_model import TrainingConfig
|
|
21
|
+
from careamics.config.transformations import (
|
|
22
|
+
N2V_TRANSFORMS_UNION,
|
|
23
|
+
SPATIAL_TRANSFORMS_UNION,
|
|
16
24
|
N2VManipulateModel,
|
|
17
25
|
XYFlipModel,
|
|
18
26
|
XYRandomRotate90Model,
|
|
19
27
|
)
|
|
20
28
|
|
|
21
29
|
|
|
22
|
-
def
|
|
23
|
-
|
|
24
|
-
|
|
30
|
+
def _algorithm_config_discriminator(
    value: Union[dict, Configuration],
) -> Optional[str]:
    """Discriminate algorithm-specific configurations based on the algorithm.

    Used as a Pydantic callable discriminator: the returned string is matched
    against the `Tag` of each configuration class in the union.

    Parameters
    ----------
    value : dict or Configuration
        Raw configuration dictionary, or an already-instantiated configuration.

    Returns
    -------
    str or None
        The value found at `algorithm_config.algorithm`, or `None` if it cannot
        be found. Returning `None` lets Pydantic report a proper validation
        error instead of leaking a `KeyError`/`AttributeError`/`TypeError`.
    """
    if isinstance(value, dict):
        algorithm_config = value.get("algorithm_config")
    else:
        algorithm_config = getattr(value, "algorithm_config", None)

    # `algorithm_config` may itself be a raw dict or an algorithm model instance.
    if isinstance(algorithm_config, dict):
        return algorithm_config.get("algorithm")
    return getattr(algorithm_config, "algorithm", None)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Built lazily on first use and then reused: every TypeAdapter construction
# compiles a fresh validator, which is costly for these large models.
_CONFIGURATION_ADAPTER: Optional[TypeAdapter] = None


def configuration_factory(
    configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
    """
    Create a configuration for training CAREamics.

    The concrete configuration class is selected from the algorithm name found
    under `algorithm_config`, using `_algorithm_config_discriminator`.

    Parameters
    ----------
    configuration : dict
        Configuration dictionary.

    Returns
    -------
    N2VConfiguration or N2NConfiguration or CAREConfiguration
        Configuration for training CAREamics.
    """
    global _CONFIGURATION_ADAPTER
    if _CONFIGURATION_ADAPTER is None:
        _CONFIGURATION_ADAPTER = TypeAdapter(
            Annotated[
                Union[
                    Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
                    Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
                    Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
                ],
                Discriminator(_algorithm_config_discriminator),
            ]
        )
    return _CONFIGURATION_ADAPTER.validate_python(configuration)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Built lazily on first use and then reused: every TypeAdapter construction
# compiles a fresh validator.
_ALGORITHM_ADAPTER: Optional[TypeAdapter] = None


def algorithm_factory(
    algorithm: dict[str, Any]
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Create an algorithm model for training CAREamics.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Algorithm model for training CAREamics.
    """
    global _ALGORITHM_ADAPTER
    if _ALGORITHM_ADAPTER is None:
        _ALGORITHM_ADAPTER = TypeAdapter(
            Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]
        )
    return _ALGORITHM_ADAPTER.validate_python(algorithm)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Built lazily on first use and then reused: every TypeAdapter construction
# compiles a fresh validator.
_DATA_ADAPTER: Optional[TypeAdapter] = None


def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
    """
    Create a data model for training CAREamics.

    Parameters
    ----------
    data : dict
        Data dictionary.

    Returns
    -------
    DataConfig or N2VDataConfig
        Data model for training CAREamics.
    """
    global _DATA_ADAPTER
    if _DATA_ADAPTER is None:
        _DATA_ADAPTER = TypeAdapter(Union[DataConfig, N2VDataConfig])
    return _DATA_ADAPTER.validate_python(data)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _list_spatial_augmentations(
|
|
116
|
+
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
|
|
117
|
+
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
25
118
|
"""
|
|
26
119
|
List the augmentations to apply.
|
|
27
120
|
|
|
@@ -44,7 +137,7 @@ def _list_augmentations(
|
|
|
44
137
|
If there are duplicate transforms.
|
|
45
138
|
"""
|
|
46
139
|
if augmentations is None:
|
|
47
|
-
transform_list: list[
|
|
140
|
+
transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
|
|
48
141
|
XYFlipModel(),
|
|
49
142
|
XYRandomRotate90Model(),
|
|
50
143
|
]
|
|
@@ -123,7 +216,7 @@ def _create_configuration(
|
|
|
123
216
|
patch_size: list[int],
|
|
124
217
|
batch_size: int,
|
|
125
218
|
num_epochs: int,
|
|
126
|
-
augmentations: list[
|
|
219
|
+
augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
|
|
127
220
|
independent_channels: bool,
|
|
128
221
|
loss: Literal["n2v", "mae", "mse"],
|
|
129
222
|
n_channels_in: int,
|
|
@@ -131,7 +224,8 @@ def _create_configuration(
|
|
|
131
224
|
logger: Literal["wandb", "tensorboard", "none"],
|
|
132
225
|
use_n2v2: bool = False,
|
|
133
226
|
model_params: Optional[dict] = None,
|
|
134
|
-
|
|
227
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
228
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
135
229
|
) -> Configuration:
|
|
136
230
|
"""
|
|
137
231
|
Create a configuration for training N2V, CARE or Noise2Noise.
|
|
@@ -169,8 +263,10 @@ def _create_configuration(
|
|
|
169
263
|
Whether to use N2V2, by default False.
|
|
170
264
|
model_params : dict
|
|
171
265
|
UNetModel parameters.
|
|
172
|
-
|
|
173
|
-
Parameters for the dataloader, see PyTorch notes, by default None.
|
|
266
|
+
train_dataloader_params : dict
|
|
267
|
+
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
268
|
+
val_dataloader_params : dict
|
|
269
|
+
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
174
270
|
|
|
175
271
|
Returns
|
|
176
272
|
-------
|
|
@@ -188,21 +284,25 @@ def _create_configuration(
|
|
|
188
284
|
)
|
|
189
285
|
|
|
190
286
|
# algorithm model
|
|
191
|
-
algorithm_config =
|
|
192
|
-
algorithm
|
|
193
|
-
loss
|
|
194
|
-
model
|
|
195
|
-
|
|
287
|
+
algorithm_config = {
|
|
288
|
+
"algorithm": algorithm,
|
|
289
|
+
"loss": loss,
|
|
290
|
+
"model": unet_model,
|
|
291
|
+
}
|
|
196
292
|
|
|
197
293
|
# data model
|
|
198
|
-
data =
|
|
199
|
-
data_type
|
|
200
|
-
axes
|
|
201
|
-
patch_size
|
|
202
|
-
batch_size
|
|
203
|
-
transforms
|
|
204
|
-
|
|
205
|
-
|
|
294
|
+
data = {
|
|
295
|
+
"data_type": data_type,
|
|
296
|
+
"axes": axes,
|
|
297
|
+
"patch_size": patch_size,
|
|
298
|
+
"batch_size": batch_size,
|
|
299
|
+
"transforms": augmentations,
|
|
300
|
+
}
|
|
301
|
+
# Don't override defaults set in DataConfig class
|
|
302
|
+
if train_dataloader_params is not None:
|
|
303
|
+
data["train_dataloader_params"] = train_dataloader_params
|
|
304
|
+
if val_dataloader_params is not None:
|
|
305
|
+
data["val_dataloader_params"] = val_dataloader_params
|
|
206
306
|
|
|
207
307
|
# training model
|
|
208
308
|
training = TrainingConfig(
|
|
@@ -212,14 +312,14 @@ def _create_configuration(
|
|
|
212
312
|
)
|
|
213
313
|
|
|
214
314
|
# create configuration
|
|
215
|
-
configuration =
|
|
216
|
-
experiment_name
|
|
217
|
-
algorithm_config
|
|
218
|
-
data_config
|
|
219
|
-
training_config
|
|
220
|
-
|
|
315
|
+
configuration = {
|
|
316
|
+
"experiment_name": experiment_name,
|
|
317
|
+
"algorithm_config": algorithm_config,
|
|
318
|
+
"data_config": data,
|
|
319
|
+
"training_config": training,
|
|
320
|
+
}
|
|
221
321
|
|
|
222
|
-
return configuration
|
|
322
|
+
return configuration_factory(configuration)
|
|
223
323
|
|
|
224
324
|
|
|
225
325
|
# TODO reconsider naming once we officially support LVAE approaches
|
|
@@ -238,7 +338,8 @@ def _create_supervised_configuration(
|
|
|
238
338
|
n_channels_out: Optional[int] = None,
|
|
239
339
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
240
340
|
model_params: Optional[dict] = None,
|
|
241
|
-
|
|
341
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
342
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
242
343
|
) -> Configuration:
|
|
243
344
|
"""
|
|
244
345
|
Create a configuration for training CARE or Noise2Noise.
|
|
@@ -275,8 +376,10 @@ def _create_supervised_configuration(
|
|
|
275
376
|
Logger to use, by default "none".
|
|
276
377
|
model_params : dict, optional
|
|
277
378
|
UNetModel parameters, by default {}.
|
|
278
|
-
|
|
279
|
-
Parameters for the dataloader, see PyTorch notes, by default None.
|
|
379
|
+
train_dataloader_params : dict
|
|
380
|
+
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
381
|
+
val_dataloader_params : dict
|
|
382
|
+
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
280
383
|
|
|
281
384
|
Returns
|
|
282
385
|
-------
|
|
@@ -306,7 +409,7 @@ def _create_supervised_configuration(
|
|
|
306
409
|
n_channels_out = n_channels_in
|
|
307
410
|
|
|
308
411
|
# augmentations
|
|
309
|
-
|
|
412
|
+
spatial_transform_list = _list_spatial_augmentations(augmentations)
|
|
310
413
|
|
|
311
414
|
return _create_configuration(
|
|
312
415
|
algorithm=algorithm,
|
|
@@ -316,14 +419,15 @@ def _create_supervised_configuration(
|
|
|
316
419
|
patch_size=patch_size,
|
|
317
420
|
batch_size=batch_size,
|
|
318
421
|
num_epochs=num_epochs,
|
|
319
|
-
augmentations=
|
|
422
|
+
augmentations=spatial_transform_list,
|
|
320
423
|
independent_channels=independent_channels,
|
|
321
424
|
loss=loss,
|
|
322
425
|
n_channels_in=n_channels_in,
|
|
323
426
|
n_channels_out=n_channels_out,
|
|
324
427
|
logger=logger,
|
|
325
428
|
model_params=model_params,
|
|
326
|
-
|
|
429
|
+
train_dataloader_params=train_dataloader_params,
|
|
430
|
+
val_dataloader_params=val_dataloader_params,
|
|
327
431
|
)
|
|
328
432
|
|
|
329
433
|
|
|
@@ -341,7 +445,8 @@ def create_care_configuration(
|
|
|
341
445
|
n_channels_out: Optional[int] = None,
|
|
342
446
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
343
447
|
model_params: Optional[dict] = None,
|
|
344
|
-
|
|
448
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
449
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
345
450
|
) -> Configuration:
|
|
346
451
|
"""
|
|
347
452
|
Create a configuration for training CARE.
|
|
@@ -394,8 +499,14 @@ def create_care_configuration(
|
|
|
394
499
|
Logger to use.
|
|
395
500
|
model_params : dict, default=None
|
|
396
501
|
UNetModel parameters.
|
|
397
|
-
|
|
398
|
-
Parameters for the dataloader, see PyTorch
|
|
502
|
+
train_dataloader_params : dict, optional
|
|
503
|
+
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
504
|
+
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
505
|
+
the `GeneralDataConfig`.
|
|
506
|
+
val_dataloader_params : dict, optional
|
|
507
|
+
Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
|
|
508
|
+
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
509
|
+
`GeneralDataConfig`.
|
|
399
510
|
|
|
400
511
|
Returns
|
|
401
512
|
-------
|
|
@@ -484,7 +595,8 @@ def create_care_configuration(
|
|
|
484
595
|
n_channels_out=n_channels_out,
|
|
485
596
|
logger=logger,
|
|
486
597
|
model_params=model_params,
|
|
487
|
-
|
|
598
|
+
train_dataloader_params=train_dataloader_params,
|
|
599
|
+
val_dataloader_params=val_dataloader_params,
|
|
488
600
|
)
|
|
489
601
|
|
|
490
602
|
|
|
@@ -502,7 +614,8 @@ def create_n2n_configuration(
|
|
|
502
614
|
n_channels_out: Optional[int] = None,
|
|
503
615
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
504
616
|
model_params: Optional[dict] = None,
|
|
505
|
-
|
|
617
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
618
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
506
619
|
) -> Configuration:
|
|
507
620
|
"""
|
|
508
621
|
Create a configuration for training Noise2Noise.
|
|
@@ -555,8 +668,14 @@ def create_n2n_configuration(
|
|
|
555
668
|
Logger to use, by default "none".
|
|
556
669
|
model_params : dict, optional
|
|
557
670
|
UNetModel parameters, by default {}.
|
|
558
|
-
|
|
559
|
-
Parameters for the dataloader, see PyTorch
|
|
671
|
+
train_dataloader_params : dict, optional
|
|
672
|
+
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
673
|
+
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
674
|
+
the `GeneralDataConfig`.
|
|
675
|
+
val_dataloader_params : dict, optional
|
|
676
|
+
Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
|
|
677
|
+
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
678
|
+
`GeneralDataConfig`.
|
|
560
679
|
|
|
561
680
|
Returns
|
|
562
681
|
-------
|
|
@@ -645,7 +764,8 @@ def create_n2n_configuration(
|
|
|
645
764
|
n_channels_out=n_channels_out,
|
|
646
765
|
logger=logger,
|
|
647
766
|
model_params=model_params,
|
|
648
|
-
|
|
767
|
+
train_dataloader_params=train_dataloader_params,
|
|
768
|
+
val_dataloader_params=val_dataloader_params,
|
|
649
769
|
)
|
|
650
770
|
|
|
651
771
|
|
|
@@ -666,7 +786,8 @@ def create_n2v_configuration(
|
|
|
666
786
|
struct_n2v_span: int = 5,
|
|
667
787
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
668
788
|
model_params: Optional[dict] = None,
|
|
669
|
-
|
|
789
|
+
train_dataloader_params: Optional[dict[str, Any]] = None,
|
|
790
|
+
val_dataloader_params: Optional[dict[str, Any]] = None,
|
|
670
791
|
) -> Configuration:
|
|
671
792
|
"""
|
|
672
793
|
Create a configuration for training Noise2Void.
|
|
@@ -745,8 +866,14 @@ def create_n2v_configuration(
|
|
|
745
866
|
Logger to use, by default "none".
|
|
746
867
|
model_params : dict, optional
|
|
747
868
|
UNetModel parameters, by default None.
|
|
748
|
-
|
|
749
|
-
Parameters for the dataloader, see PyTorch
|
|
869
|
+
train_dataloader_params : dict, optional
|
|
870
|
+
Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
|
|
871
|
+
If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
|
|
872
|
+
the `GeneralDataConfig`.
|
|
873
|
+
val_dataloader_params : dict, optional
|
|
874
|
+
Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
|
|
875
|
+
If left as `None`, the empty dict `{}` will be used, this is set in the
|
|
876
|
+
`GeneralDataConfig`.
|
|
750
877
|
|
|
751
878
|
Returns
|
|
752
879
|
-------
|
|
@@ -853,7 +980,7 @@ def create_n2v_configuration(
|
|
|
853
980
|
n_channels = 1
|
|
854
981
|
|
|
855
982
|
# augmentations
|
|
856
|
-
|
|
983
|
+
spatial_transforms = _list_spatial_augmentations(augmentations)
|
|
857
984
|
|
|
858
985
|
# create the N2VManipulate transform using the supplied parameters
|
|
859
986
|
n2v_transform = N2VManipulateModel(
|
|
@@ -868,7 +995,7 @@ def create_n2v_configuration(
|
|
|
868
995
|
struct_mask_axis=struct_n2v_axis,
|
|
869
996
|
struct_mask_span=struct_n2v_span,
|
|
870
997
|
)
|
|
871
|
-
transform_list
|
|
998
|
+
transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
|
|
872
999
|
|
|
873
1000
|
return _create_configuration(
|
|
874
1001
|
algorithm="n2v",
|
|
@@ -886,5 +1013,6 @@ def create_n2v_configuration(
|
|
|
886
1013
|
n_channels_out=n_channels,
|
|
887
1014
|
logger=logger,
|
|
888
1015
|
model_params=model_params,
|
|
889
|
-
|
|
1016
|
+
train_dataloader_params=train_dataloader_params,
|
|
1017
|
+
val_dataloader_params=val_dataloader_params,
|
|
890
1018
|
)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""I/O functions for Configuration objects."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration, configuration_factory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    The YAML is parsed with `yaml.SafeLoader`, and the resulting dictionary is
    validated by `configuration_factory`.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    config_file = Path(path)
    if not config_file.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in {Path.cwd()!s}"
        )

    # Context manager guarantees the handle is closed, even if parsing fails.
    with config_file.open("r") as f:
        dictionary = yaml.load(f, Loader=yaml.SafeLoader)

    return configuration_factory(dictionary)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Either an existing directory, in which case the configuration is written
        to `config.yml` inside it, or a file path with a `.yml` or `.yaml`
        extension.

    Returns
    -------
    Path
        Path of the written configuration file.

    Raises
    ------
    ValueError
        If `path` is neither an existing directory nor a path ending in `.yml`
        or `.yaml`.
    """
    target = Path(path)

    # `is_dir()` is False for paths that do not exist, so anything that is not
    # an existing directory must carry a YAML suffix.
    if target.is_dir():
        target = Path(target, "config.yml")
    elif target.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {target})."
        )

    with open(target, "w") as f:
        yaml.dump(
            config.model_dump(), f, default_flow_style=False, sort_keys=False
        )

    return target
|