careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,26 +2,93 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Literal, Optional, Union
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
5
|
+
from pydantic import TypeAdapter
|
|
6
|
+
|
|
7
|
+
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
8
|
+
from careamics.config.architectures import UNetModel
|
|
9
|
+
from careamics.config.care_configuration import CAREConfiguration
|
|
10
|
+
from careamics.config.configuration import Configuration
|
|
11
|
+
from careamics.config.data import DataConfig, N2VDataConfig
|
|
12
|
+
from careamics.config.n2n_configuration import N2NConfiguration
|
|
13
|
+
from careamics.config.n2v_configuration import N2VConfiguration
|
|
14
|
+
from careamics.config.support import (
|
|
10
15
|
SupportedArchitecture,
|
|
11
16
|
SupportedPixelManipulation,
|
|
12
17
|
SupportedTransform,
|
|
13
18
|
)
|
|
14
|
-
from .training_model import TrainingConfig
|
|
15
|
-
from .transformations import (
|
|
19
|
+
from careamics.config.training_model import TrainingConfig
|
|
20
|
+
from careamics.config.transformations import (
|
|
21
|
+
N2V_TRANSFORMS_UNION,
|
|
22
|
+
SPATIAL_TRANSFORMS_UNION,
|
|
16
23
|
N2VManipulateModel,
|
|
17
24
|
XYFlipModel,
|
|
18
25
|
XYRandomRotate90Model,
|
|
19
26
|
)
|
|
20
27
|
|
|
21
28
|
|
|
22
|
-
def
|
|
23
|
-
|
|
24
|
-
) ->
|
|
29
|
+
def configuration_factory(
    configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
    """
    Build a CAREamics training configuration from a plain dictionary.

    The dictionary is validated with a pydantic ``TypeAdapter`` built over the
    union of ``N2VConfiguration``, ``N2NConfiguration`` and
    ``CAREConfiguration``.

    Parameters
    ----------
    configuration : dict
        Configuration dictionary.

    Returns
    -------
    N2VConfiguration or N2NConfiguration or CAREConfiguration
        Validated configuration for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against the configuration union.
    """
    supported_configs = Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
    validator: TypeAdapter = TypeAdapter(supported_configs)
    return validator.validate_python(configuration)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def algorithm_factory(
    algorithm: dict[str, Any]
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Build a CAREamics algorithm model from a plain dictionary.

    The dictionary is validated with a pydantic ``TypeAdapter`` built over the
    union of ``N2VAlgorithm``, ``N2NAlgorithm`` and ``CAREAlgorithm``.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Validated algorithm model for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against the algorithm union.
    """
    supported_algorithms = Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]
    validator: TypeAdapter = TypeAdapter(supported_algorithms)
    return validator.validate_python(algorithm)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
    """
    Build a CAREamics data model from a plain dictionary.

    The dictionary is validated with a pydantic ``TypeAdapter`` built over the
    union of ``DataConfig`` and ``N2VDataConfig``.

    Parameters
    ----------
    data : dict
        Data dictionary.

    Returns
    -------
    DataConfig or N2VDataConfig
        Validated data model for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against the data-model union.
    """
    supported_data_models = Union[DataConfig, N2VDataConfig]
    validator: TypeAdapter = TypeAdapter(supported_data_models)
    return validator.validate_python(data)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _list_spatial_augmentations(
|
|
90
|
+
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
|
|
91
|
+
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
25
92
|
"""
|
|
26
93
|
List the augmentations to apply.
|
|
27
94
|
|
|
@@ -44,7 +111,7 @@ def _list_augmentations(
|
|
|
44
111
|
If there are duplicate transforms.
|
|
45
112
|
"""
|
|
46
113
|
if augmentations is None:
|
|
47
|
-
transform_list: list[
|
|
114
|
+
transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
|
|
48
115
|
XYFlipModel(),
|
|
49
116
|
XYRandomRotate90Model(),
|
|
50
117
|
]
|
|
@@ -123,7 +190,7 @@ def _create_configuration(
|
|
|
123
190
|
patch_size: list[int],
|
|
124
191
|
batch_size: int,
|
|
125
192
|
num_epochs: int,
|
|
126
|
-
augmentations: list[
|
|
193
|
+
augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
|
|
127
194
|
independent_channels: bool,
|
|
128
195
|
loss: Literal["n2v", "mae", "mse"],
|
|
129
196
|
n_channels_in: int,
|
|
@@ -188,21 +255,21 @@ def _create_configuration(
|
|
|
188
255
|
)
|
|
189
256
|
|
|
190
257
|
# algorithm model
|
|
191
|
-
algorithm_config =
|
|
192
|
-
algorithm
|
|
193
|
-
loss
|
|
194
|
-
model
|
|
195
|
-
|
|
258
|
+
algorithm_config = {
|
|
259
|
+
"algorithm": algorithm,
|
|
260
|
+
"loss": loss,
|
|
261
|
+
"model": unet_model,
|
|
262
|
+
}
|
|
196
263
|
|
|
197
264
|
# data model
|
|
198
|
-
data =
|
|
199
|
-
data_type
|
|
200
|
-
axes
|
|
201
|
-
patch_size
|
|
202
|
-
batch_size
|
|
203
|
-
transforms
|
|
204
|
-
dataloader_params
|
|
205
|
-
|
|
265
|
+
data = {
|
|
266
|
+
"data_type": data_type,
|
|
267
|
+
"axes": axes,
|
|
268
|
+
"patch_size": patch_size,
|
|
269
|
+
"batch_size": batch_size,
|
|
270
|
+
"transforms": augmentations,
|
|
271
|
+
"dataloader_params": dataloader_params,
|
|
272
|
+
}
|
|
206
273
|
|
|
207
274
|
# training model
|
|
208
275
|
training = TrainingConfig(
|
|
@@ -212,14 +279,14 @@ def _create_configuration(
|
|
|
212
279
|
)
|
|
213
280
|
|
|
214
281
|
# create configuration
|
|
215
|
-
configuration =
|
|
216
|
-
experiment_name
|
|
217
|
-
algorithm_config
|
|
218
|
-
data_config
|
|
219
|
-
training_config
|
|
220
|
-
|
|
282
|
+
configuration = {
|
|
283
|
+
"experiment_name": experiment_name,
|
|
284
|
+
"algorithm_config": algorithm_config,
|
|
285
|
+
"data_config": data,
|
|
286
|
+
"training_config": training,
|
|
287
|
+
}
|
|
221
288
|
|
|
222
|
-
return configuration
|
|
289
|
+
return configuration_factory(configuration)
|
|
223
290
|
|
|
224
291
|
|
|
225
292
|
# TODO reconsider naming once we officially support LVAE approaches
|
|
@@ -234,8 +301,8 @@ def _create_supervised_configuration(
|
|
|
234
301
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
235
302
|
independent_channels: bool = True,
|
|
236
303
|
loss: Literal["mae", "mse"] = "mae",
|
|
237
|
-
n_channels_in: int =
|
|
238
|
-
n_channels_out: int =
|
|
304
|
+
n_channels_in: Optional[int] = None,
|
|
305
|
+
n_channels_out: Optional[int] = None,
|
|
239
306
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
240
307
|
model_params: Optional[dict] = None,
|
|
241
308
|
dataloader_params: Optional[dict] = None,
|
|
@@ -267,10 +334,10 @@ def _create_supervised_configuration(
|
|
|
267
334
|
Whether to train all channels independently, by default False.
|
|
268
335
|
loss : Literal["mae", "mse"], optional
|
|
269
336
|
Loss function to use, by default "mae".
|
|
270
|
-
n_channels_in : int,
|
|
271
|
-
Number of channels in
|
|
272
|
-
n_channels_out : int,
|
|
273
|
-
Number of channels out
|
|
337
|
+
n_channels_in : int or None, default=None
|
|
338
|
+
Number of channels in.
|
|
339
|
+
n_channels_out : int or None, default=None
|
|
340
|
+
Number of channels out.
|
|
274
341
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
275
342
|
Logger to use, by default "none".
|
|
276
343
|
model_params : dict, optional
|
|
@@ -282,21 +349,31 @@ def _create_supervised_configuration(
|
|
|
282
349
|
-------
|
|
283
350
|
Configuration
|
|
284
351
|
Configuration for training CARE or Noise2Noise.
|
|
352
|
+
|
|
353
|
+
Raises
|
|
354
|
+
------
|
|
355
|
+
ValueError
|
|
356
|
+
If the number of channels is not specified when using channels.
|
|
357
|
+
ValueError
|
|
358
|
+
If the number of channels is specified but "C" is not in the axes.
|
|
285
359
|
"""
|
|
286
360
|
# if there are channels, we need to specify their number
|
|
287
|
-
if "C" in axes and n_channels_in
|
|
288
|
-
raise ValueError(
|
|
289
|
-
|
|
290
|
-
f"(got {n_channels_in} channel)."
|
|
291
|
-
)
|
|
292
|
-
elif "C" not in axes and n_channels_in > 1:
|
|
361
|
+
if "C" in axes and n_channels_in is None:
|
|
362
|
+
raise ValueError("Number of channels in must be specified when using channels ")
|
|
363
|
+
elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
|
|
293
364
|
raise ValueError(
|
|
294
365
|
f"C is not present in the axes, but number of channels is specified "
|
|
295
366
|
f"(got {n_channels_in} channels)."
|
|
296
367
|
)
|
|
297
368
|
|
|
369
|
+
if n_channels_in is None:
|
|
370
|
+
n_channels_in = 1
|
|
371
|
+
|
|
372
|
+
if n_channels_out is None:
|
|
373
|
+
n_channels_out = n_channels_in
|
|
374
|
+
|
|
298
375
|
# augmentations
|
|
299
|
-
|
|
376
|
+
spatial_transform_list = _list_spatial_augmentations(augmentations)
|
|
300
377
|
|
|
301
378
|
return _create_configuration(
|
|
302
379
|
algorithm=algorithm,
|
|
@@ -306,7 +383,7 @@ def _create_supervised_configuration(
|
|
|
306
383
|
patch_size=patch_size,
|
|
307
384
|
batch_size=batch_size,
|
|
308
385
|
num_epochs=num_epochs,
|
|
309
|
-
augmentations=
|
|
386
|
+
augmentations=spatial_transform_list,
|
|
310
387
|
independent_channels=independent_channels,
|
|
311
388
|
loss=loss,
|
|
312
389
|
n_channels_in=n_channels_in,
|
|
@@ -327,8 +404,8 @@ def create_care_configuration(
|
|
|
327
404
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
328
405
|
independent_channels: bool = True,
|
|
329
406
|
loss: Literal["mae", "mse"] = "mae",
|
|
330
|
-
n_channels_in: int =
|
|
331
|
-
n_channels_out: int =
|
|
407
|
+
n_channels_in: Optional[int] = None,
|
|
408
|
+
n_channels_out: Optional[int] = None,
|
|
332
409
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
333
410
|
model_params: Optional[dict] = None,
|
|
334
411
|
dataloader_params: Optional[dict] = None,
|
|
@@ -374,16 +451,16 @@ def create_care_configuration(
|
|
|
374
451
|
and XYRandomRotate90 (in XY) to the images.
|
|
375
452
|
independent_channels : bool, optional
|
|
376
453
|
Whether to train all channels independently, by default False.
|
|
377
|
-
loss : Literal["mae", "mse"],
|
|
378
|
-
Loss function to use
|
|
379
|
-
n_channels_in : int,
|
|
380
|
-
Number of channels in
|
|
381
|
-
n_channels_out : int,
|
|
382
|
-
Number of channels out
|
|
383
|
-
logger : Literal["wandb", "tensorboard", "none"],
|
|
384
|
-
Logger to use
|
|
385
|
-
model_params : dict,
|
|
386
|
-
UNetModel parameters
|
|
454
|
+
loss : Literal["mae", "mse"], default="mae"
|
|
455
|
+
Loss function to use.
|
|
456
|
+
n_channels_in : int or None, default=None
|
|
457
|
+
Number of channels in.
|
|
458
|
+
n_channels_out : int or None, default=None
|
|
459
|
+
Number of channels out.
|
|
460
|
+
logger : Literal["wandb", "tensorboard", "none"], default="none"
|
|
461
|
+
Logger to use.
|
|
462
|
+
model_params : dict, default=None
|
|
463
|
+
UNetModel parameters.
|
|
387
464
|
dataloader_params : dict, optional
|
|
388
465
|
Parameters for the dataloader, see PyTorch notes, by default None.
|
|
389
466
|
|
|
@@ -459,9 +536,6 @@ def create_care_configuration(
|
|
|
459
536
|
... n_channels_out=1 # if applicable
|
|
460
537
|
... )
|
|
461
538
|
"""
|
|
462
|
-
if n_channels_out == -1:
|
|
463
|
-
n_channels_out = n_channels_in
|
|
464
|
-
|
|
465
539
|
return _create_supervised_configuration(
|
|
466
540
|
algorithm="care",
|
|
467
541
|
experiment_name=experiment_name,
|
|
@@ -491,8 +565,8 @@ def create_n2n_configuration(
|
|
|
491
565
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
492
566
|
independent_channels: bool = True,
|
|
493
567
|
loss: Literal["mae", "mse"] = "mae",
|
|
494
|
-
n_channels_in: int =
|
|
495
|
-
n_channels_out: int =
|
|
568
|
+
n_channels_in: Optional[int] = None,
|
|
569
|
+
n_channels_out: Optional[int] = None,
|
|
496
570
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
497
571
|
model_params: Optional[dict] = None,
|
|
498
572
|
dataloader_params: Optional[dict] = None,
|
|
@@ -540,10 +614,10 @@ def create_n2n_configuration(
|
|
|
540
614
|
Whether to train all channels independently, by default False.
|
|
541
615
|
loss : Literal["mae", "mse"], optional
|
|
542
616
|
Loss function to use, by default "mae".
|
|
543
|
-
n_channels_in : int,
|
|
544
|
-
Number of channels in
|
|
545
|
-
n_channels_out : int,
|
|
546
|
-
Number of channels out
|
|
617
|
+
n_channels_in : int or None, default=None
|
|
618
|
+
Number of channels in.
|
|
619
|
+
n_channels_out : int or None, default=None
|
|
620
|
+
Number of channels out.
|
|
547
621
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
548
622
|
Logger to use, by default "none".
|
|
549
623
|
model_params : dict, optional
|
|
@@ -623,9 +697,6 @@ def create_n2n_configuration(
|
|
|
623
697
|
... n_channels_out=1 # if applicable
|
|
624
698
|
... )
|
|
625
699
|
"""
|
|
626
|
-
if n_channels_out == -1:
|
|
627
|
-
n_channels_out = n_channels_in
|
|
628
|
-
|
|
629
700
|
return _create_supervised_configuration(
|
|
630
701
|
algorithm="n2n",
|
|
631
702
|
experiment_name=experiment_name,
|
|
@@ -655,7 +726,7 @@ def create_n2v_configuration(
|
|
|
655
726
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
656
727
|
independent_channels: bool = True,
|
|
657
728
|
use_n2v2: bool = False,
|
|
658
|
-
n_channels: int =
|
|
729
|
+
n_channels: Optional[int] = None,
|
|
659
730
|
roi_size: int = 11,
|
|
660
731
|
masked_pixel_percentage: float = 0.2,
|
|
661
732
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
@@ -727,8 +798,8 @@ def create_n2v_configuration(
|
|
|
727
798
|
Whether to train all channels together, by default True.
|
|
728
799
|
use_n2v2 : bool, optional
|
|
729
800
|
Whether to use N2V2, by default False.
|
|
730
|
-
n_channels : int,
|
|
731
|
-
Number of channels (in and out)
|
|
801
|
+
n_channels : int or None, default=None
|
|
802
|
+
Number of channels (in and out).
|
|
732
803
|
roi_size : int, optional
|
|
733
804
|
N2V pixel manipulation area, by default 11.
|
|
734
805
|
masked_pixel_percentage : float, optional
|
|
@@ -837,19 +908,19 @@ def create_n2v_configuration(
|
|
|
837
908
|
... )
|
|
838
909
|
"""
|
|
839
910
|
# if there are channels, we need to specify their number
|
|
840
|
-
if "C" in axes and n_channels
|
|
841
|
-
raise ValueError(
|
|
842
|
-
|
|
843
|
-
f"(got {n_channels} channel)."
|
|
844
|
-
)
|
|
845
|
-
elif "C" not in axes and n_channels > 1:
|
|
911
|
+
if "C" in axes and n_channels is None:
|
|
912
|
+
raise ValueError("Number of channels must be specified when using channels.")
|
|
913
|
+
elif "C" not in axes and (n_channels is not None and n_channels > 1):
|
|
846
914
|
raise ValueError(
|
|
847
915
|
f"C is not present in the axes, but number of channels is specified "
|
|
848
916
|
f"(got {n_channels} channel)."
|
|
849
917
|
)
|
|
850
918
|
|
|
919
|
+
if n_channels is None:
|
|
920
|
+
n_channels = 1
|
|
921
|
+
|
|
851
922
|
# augmentations
|
|
852
|
-
|
|
923
|
+
spatial_transforms = _list_spatial_augmentations(augmentations)
|
|
853
924
|
|
|
854
925
|
# create the N2VManipulate transform using the supplied parameters
|
|
855
926
|
n2v_transform = N2VManipulateModel(
|
|
@@ -864,7 +935,7 @@ def create_n2v_configuration(
|
|
|
864
935
|
struct_mask_axis=struct_n2v_axis,
|
|
865
936
|
struct_mask_span=struct_n2v_span,
|
|
866
937
|
)
|
|
867
|
-
transform_list
|
|
938
|
+
transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
|
|
868
939
|
|
|
869
940
|
return _create_configuration(
|
|
870
941
|
algorithm="n2v",
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""I/O functions for Configuration objects."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration, configuration_factory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    The file is parsed with ``yaml.safe_load`` and the resulting dictionary is
    validated by ``configuration_factory``.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Validated configuration, as returned by ``configuration_factory``.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in {Path.cwd()!s}"
        )

    # load dictionary from yaml; the context manager closes the file handle
    # even if parsing fails (safe_load == load with SafeLoader)
    with config_path.open("r", encoding="utf-8") as f:
        dictionary = yaml.safe_load(f)

    return configuration_factory(dictionary)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Path to an existing folder in which to save the configuration (it is
        then written to ``config.yml`` inside that folder), or to a
        configuration file path with a .yml or .yaml extension.

    Returns
    -------
    Path
        Path of the written configuration file.

    Raises
    ------
    ValueError
        If the path does not point to an existing directory or to a .yml/.yaml
        file.
    """
    # make sure path is a Path object
    config_path = Path(path)

    # an existing directory receives a default file name; any other path
    # (existing or not, since is_dir() is False for missing paths) must carry
    # a yaml extension
    if config_path.is_dir():
        config_path = config_path / "config.yml"
    elif config_path.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    # save configuration as dictionary to yaml
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path
|