careamics 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +49 -49
- careamics/cli/conf.py +6 -6
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/algorithms/vae_algorithm_model.py +4 -4
- careamics/config/callback_model.py +8 -8
- careamics/config/configuration_factories.py +49 -49
- careamics/config/data/data_model.py +7 -13
- careamics/config/data/ng_data_model.py +8 -14
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -10
- careamics/config/likelihood_model.py +2 -2
- careamics/config/nm_model.py +5 -7
- careamics/config/training_model.py +4 -4
- careamics/config/transformations/normalize_model.py +3 -3
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -4
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +12 -14
- careamics/lightning/predict_data_module.py +8 -8
- careamics/lightning/train_data_module.py +11 -11
- careamics/losses/lvae/losses.py +9 -9
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +4 -4
- careamics/models/layers.py +5 -5
- careamics/models/unet.py +16 -10
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +3 -5
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/RECORD +57 -57
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any, Literal,
|
|
5
|
+
from typing import Any, Literal, Union, overload
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from numpy.typing import NDArray
|
|
@@ -79,8 +79,8 @@ class CAREamist:
|
|
|
79
79
|
def __init__( # numpydoc ignore=GL08
|
|
80
80
|
self,
|
|
81
81
|
source: Union[Path, str],
|
|
82
|
-
work_dir:
|
|
83
|
-
callbacks:
|
|
82
|
+
work_dir: Union[Path, str] | None = None,
|
|
83
|
+
callbacks: list[Callback] | None = None,
|
|
84
84
|
enable_progress_bar: bool = True,
|
|
85
85
|
) -> None: ...
|
|
86
86
|
|
|
@@ -88,16 +88,16 @@ class CAREamist:
|
|
|
88
88
|
def __init__( # numpydoc ignore=GL08
|
|
89
89
|
self,
|
|
90
90
|
source: Configuration,
|
|
91
|
-
work_dir:
|
|
92
|
-
callbacks:
|
|
91
|
+
work_dir: Union[Path, str] | None = None,
|
|
92
|
+
callbacks: list[Callback] | None = None,
|
|
93
93
|
enable_progress_bar: bool = True,
|
|
94
94
|
) -> None: ...
|
|
95
95
|
|
|
96
96
|
def __init__(
|
|
97
97
|
self,
|
|
98
98
|
source: Union[Path, str, Configuration],
|
|
99
|
-
work_dir:
|
|
100
|
-
callbacks:
|
|
99
|
+
work_dir: Union[Path, str] | None = None,
|
|
100
|
+
callbacks: list[Callback] | None = None,
|
|
101
101
|
enable_progress_bar: bool = True,
|
|
102
102
|
) -> None:
|
|
103
103
|
"""
|
|
@@ -222,11 +222,11 @@ class CAREamist:
|
|
|
222
222
|
)
|
|
223
223
|
|
|
224
224
|
# place holder for the datamodules
|
|
225
|
-
self.train_datamodule:
|
|
226
|
-
self.pred_datamodule:
|
|
225
|
+
self.train_datamodule: TrainDataModule | None = None
|
|
226
|
+
self.pred_datamodule: PredictDataModule | None = None
|
|
227
227
|
|
|
228
228
|
def _define_callbacks(
|
|
229
|
-
self, callbacks:
|
|
229
|
+
self, callbacks: list[Callback] | None, enable_progress_bar: bool
|
|
230
230
|
) -> None:
|
|
231
231
|
"""Define the callbacks for the training loop.
|
|
232
232
|
|
|
@@ -288,11 +288,11 @@ class CAREamist:
|
|
|
288
288
|
def train(
|
|
289
289
|
self,
|
|
290
290
|
*,
|
|
291
|
-
datamodule:
|
|
292
|
-
train_source:
|
|
293
|
-
val_source:
|
|
294
|
-
train_target:
|
|
295
|
-
val_target:
|
|
291
|
+
datamodule: TrainDataModule | None = None,
|
|
292
|
+
train_source: Union[Path, str, NDArray] | None = None,
|
|
293
|
+
val_source: Union[Path, str, NDArray] | None = None,
|
|
294
|
+
train_target: Union[Path, str, NDArray] | None = None,
|
|
295
|
+
val_target: Union[Path, str, NDArray] | None = None,
|
|
296
296
|
use_in_memory: bool = True,
|
|
297
297
|
val_percentage: float = 0.1,
|
|
298
298
|
val_minimum_split: int = 1,
|
|
@@ -443,9 +443,9 @@ class CAREamist:
|
|
|
443
443
|
def _train_on_array(
|
|
444
444
|
self,
|
|
445
445
|
train_data: NDArray,
|
|
446
|
-
val_data:
|
|
447
|
-
train_target:
|
|
448
|
-
val_target:
|
|
446
|
+
val_data: NDArray | None = None,
|
|
447
|
+
train_target: NDArray | None = None,
|
|
448
|
+
val_target: NDArray | None = None,
|
|
449
449
|
val_percentage: float = 0.1,
|
|
450
450
|
val_minimum_split: int = 5,
|
|
451
451
|
) -> None:
|
|
@@ -484,9 +484,9 @@ class CAREamist:
|
|
|
484
484
|
def _train_on_path(
|
|
485
485
|
self,
|
|
486
486
|
path_to_train_data: Union[Path, str],
|
|
487
|
-
path_to_val_data:
|
|
488
|
-
path_to_train_target:
|
|
489
|
-
path_to_val_target:
|
|
487
|
+
path_to_val_data: Union[Path, str] | None = None,
|
|
488
|
+
path_to_train_target: Union[Path, str] | None = None,
|
|
489
|
+
path_to_val_target: Union[Path, str] | None = None,
|
|
490
490
|
use_in_memory: bool = True,
|
|
491
491
|
val_percentage: float = 0.1,
|
|
492
492
|
val_minimum_split: int = 1,
|
|
@@ -549,13 +549,13 @@ class CAREamist:
|
|
|
549
549
|
source: Union[Path, str],
|
|
550
550
|
*,
|
|
551
551
|
batch_size: int = 1,
|
|
552
|
-
tile_size:
|
|
553
|
-
tile_overlap:
|
|
554
|
-
axes:
|
|
555
|
-
data_type:
|
|
552
|
+
tile_size: tuple[int, ...] | None = None,
|
|
553
|
+
tile_overlap: tuple[int, ...] | None = (48, 48),
|
|
554
|
+
axes: str | None = None,
|
|
555
|
+
data_type: Literal["tiff", "custom"] | None = None,
|
|
556
556
|
tta_transforms: bool = False,
|
|
557
|
-
dataloader_params:
|
|
558
|
-
read_source_func:
|
|
557
|
+
dataloader_params: dict | None = None,
|
|
558
|
+
read_source_func: Callable | None = None,
|
|
559
559
|
extension_filter: str = "",
|
|
560
560
|
) -> Union[list[NDArray], NDArray]: ...
|
|
561
561
|
|
|
@@ -565,12 +565,12 @@ class CAREamist:
|
|
|
565
565
|
source: NDArray,
|
|
566
566
|
*,
|
|
567
567
|
batch_size: int = 1,
|
|
568
|
-
tile_size:
|
|
569
|
-
tile_overlap:
|
|
570
|
-
axes:
|
|
571
|
-
data_type:
|
|
568
|
+
tile_size: tuple[int, ...] | None = None,
|
|
569
|
+
tile_overlap: tuple[int, ...] | None = (48, 48),
|
|
570
|
+
axes: str | None = None,
|
|
571
|
+
data_type: Literal["array"] | None = None,
|
|
572
572
|
tta_transforms: bool = False,
|
|
573
|
-
dataloader_params:
|
|
573
|
+
dataloader_params: dict | None = None,
|
|
574
574
|
) -> Union[list[NDArray], NDArray]: ...
|
|
575
575
|
|
|
576
576
|
def predict(
|
|
@@ -578,13 +578,13 @@ class CAREamist:
|
|
|
578
578
|
source: Union[PredictDataModule, Path, str, NDArray],
|
|
579
579
|
*,
|
|
580
580
|
batch_size: int = 1,
|
|
581
|
-
tile_size:
|
|
582
|
-
tile_overlap:
|
|
583
|
-
axes:
|
|
584
|
-
data_type:
|
|
581
|
+
tile_size: tuple[int, ...] | None = None,
|
|
582
|
+
tile_overlap: tuple[int, ...] | None = (48, 48),
|
|
583
|
+
axes: str | None = None,
|
|
584
|
+
data_type: Literal["array", "tiff", "custom"] | None = None,
|
|
585
585
|
tta_transforms: bool = False,
|
|
586
|
-
dataloader_params:
|
|
587
|
-
read_source_func:
|
|
586
|
+
dataloader_params: dict | None = None,
|
|
587
|
+
read_source_func: Callable | None = None,
|
|
588
588
|
extension_filter: str = "",
|
|
589
589
|
**kwargs: Any,
|
|
590
590
|
) -> Union[list[NDArray], NDArray]:
|
|
@@ -704,18 +704,18 @@ class CAREamist:
|
|
|
704
704
|
source: Union[PredictDataModule, Path, str],
|
|
705
705
|
*,
|
|
706
706
|
batch_size: int = 1,
|
|
707
|
-
tile_size:
|
|
708
|
-
tile_overlap:
|
|
709
|
-
axes:
|
|
710
|
-
data_type:
|
|
707
|
+
tile_size: tuple[int, ...] | None = None,
|
|
708
|
+
tile_overlap: tuple[int, ...] | None = (48, 48),
|
|
709
|
+
axes: str | None = None,
|
|
710
|
+
data_type: Literal["tiff", "custom"] | None = None,
|
|
711
711
|
tta_transforms: bool = False,
|
|
712
|
-
dataloader_params:
|
|
713
|
-
read_source_func:
|
|
712
|
+
dataloader_params: dict | None = None,
|
|
713
|
+
read_source_func: Callable | None = None,
|
|
714
714
|
extension_filter: str = "",
|
|
715
715
|
write_type: Literal["tiff", "custom"] = "tiff",
|
|
716
|
-
write_extension:
|
|
717
|
-
write_func:
|
|
718
|
-
write_func_kwargs:
|
|
716
|
+
write_extension: str | None = None,
|
|
717
|
+
write_func: WriteFunc | None = None,
|
|
718
|
+
write_func_kwargs: dict[str, Any] | None = None,
|
|
719
719
|
prediction_dir: Union[Path, str] = "predictions",
|
|
720
720
|
**kwargs,
|
|
721
721
|
) -> None:
|
|
@@ -885,8 +885,8 @@ class CAREamist:
|
|
|
885
885
|
authors: list[dict],
|
|
886
886
|
general_description: str,
|
|
887
887
|
data_description: str,
|
|
888
|
-
covers:
|
|
889
|
-
channel_names:
|
|
888
|
+
covers: list[Union[Path, str]] | None = None,
|
|
889
|
+
channel_names: list[str] | None = None,
|
|
890
890
|
model_version: str = "0.1.0",
|
|
891
891
|
) -> None:
|
|
892
892
|
"""Export the model to the BioImage Model Zoo format.
|
careamics/cli/conf.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import sys
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Annotated
|
|
6
|
+
from typing import Annotated
|
|
7
7
|
|
|
8
8
|
import click
|
|
9
9
|
import typer
|
|
@@ -135,10 +135,10 @@ def care( # numpydoc ignore=PR01
|
|
|
135
135
|
),
|
|
136
136
|
] = "mae",
|
|
137
137
|
n_channels_in: Annotated[
|
|
138
|
-
|
|
138
|
+
int | None, typer.Option(help="Number of channels in")
|
|
139
139
|
] = None,
|
|
140
140
|
n_channels_out: Annotated[
|
|
141
|
-
|
|
141
|
+
int | None, typer.Option(help="Number of channels out")
|
|
142
142
|
] = None,
|
|
143
143
|
logger: Annotated[
|
|
144
144
|
click.Choice,
|
|
@@ -222,10 +222,10 @@ def n2n( # numpydoc ignore=PR01
|
|
|
222
222
|
),
|
|
223
223
|
] = "mae",
|
|
224
224
|
n_channels_in: Annotated[
|
|
225
|
-
|
|
225
|
+
int | None, typer.Option(help="Number of channels in")
|
|
226
226
|
] = None,
|
|
227
227
|
n_channels_out: Annotated[
|
|
228
|
-
|
|
228
|
+
int | None, typer.Option(help="Number of channels out")
|
|
229
229
|
] = None,
|
|
230
230
|
logger: Annotated[
|
|
231
231
|
click.Choice,
|
|
@@ -300,7 +300,7 @@ def n2v( # numpydoc ignore=PR01
|
|
|
300
300
|
] = True,
|
|
301
301
|
use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
|
|
302
302
|
n_channels: Annotated[
|
|
303
|
-
|
|
303
|
+
int | None, typer.Option(help="Number of channels (in and out)")
|
|
304
304
|
] = None,
|
|
305
305
|
roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
|
|
306
306
|
masked_pixel_percentage: Annotated[
|
careamics/cli/main.py
CHANGED
|
@@ -7,7 +7,7 @@ its implementation is contained in the conf.py file.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Annotated
|
|
10
|
+
from typing import Annotated
|
|
11
11
|
|
|
12
12
|
import click
|
|
13
13
|
import typer
|
|
@@ -47,7 +47,7 @@ def train( # numpydoc ignore=PR01
|
|
|
47
47
|
),
|
|
48
48
|
],
|
|
49
49
|
train_target: Annotated[
|
|
50
|
-
|
|
50
|
+
Path | None,
|
|
51
51
|
typer.Option(
|
|
52
52
|
"--train-target",
|
|
53
53
|
"-tt",
|
|
@@ -58,7 +58,7 @@ def train( # numpydoc ignore=PR01
|
|
|
58
58
|
),
|
|
59
59
|
] = None,
|
|
60
60
|
val_source: Annotated[
|
|
61
|
-
|
|
61
|
+
Path | None,
|
|
62
62
|
typer.Option(
|
|
63
63
|
"--val-source",
|
|
64
64
|
"-vs",
|
|
@@ -69,7 +69,7 @@ def train( # numpydoc ignore=PR01
|
|
|
69
69
|
),
|
|
70
70
|
] = None,
|
|
71
71
|
val_target: Annotated[
|
|
72
|
-
|
|
72
|
+
Path | None,
|
|
73
73
|
typer.Option(
|
|
74
74
|
"--val-target",
|
|
75
75
|
"-vt",
|
|
@@ -96,7 +96,7 @@ def train( # numpydoc ignore=PR01
|
|
|
96
96
|
typer.Option(help="Minimum number of files to use for validation,"),
|
|
97
97
|
] = 1,
|
|
98
98
|
work_dir: Annotated[
|
|
99
|
-
|
|
99
|
+
Path | None,
|
|
100
100
|
typer.Option(
|
|
101
101
|
"--work-dir",
|
|
102
102
|
"-wd",
|
|
@@ -142,7 +142,7 @@ def predict( # numpydoc ignore=PR01
|
|
|
142
142
|
],
|
|
143
143
|
batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
|
|
144
144
|
tile_size: Annotated[
|
|
145
|
-
|
|
145
|
+
click.Tuple | None,
|
|
146
146
|
typer.Option(
|
|
147
147
|
help=(
|
|
148
148
|
"Size of the tiles to use for prediction, (if the data "
|
|
@@ -164,7 +164,7 @@ def predict( # numpydoc ignore=PR01
|
|
|
164
164
|
),
|
|
165
165
|
] = (48, 48, -1),
|
|
166
166
|
axes: Annotated[
|
|
167
|
-
|
|
167
|
+
str | None,
|
|
168
168
|
typer.Option(
|
|
169
169
|
help="Axes of the input data. If unused the data is assumed to have the "
|
|
170
170
|
"same axes as the original training data."
|
|
@@ -190,7 +190,7 @@ def predict( # numpydoc ignore=PR01
|
|
|
190
190
|
] = "tiff",
|
|
191
191
|
# TODO: could make dataloader_params as json, necessary?
|
|
192
192
|
work_dir: Annotated[
|
|
193
|
-
|
|
193
|
+
Path | None,
|
|
194
194
|
typer.Option(
|
|
195
195
|
"--work-dir",
|
|
196
196
|
"-wd",
|
careamics/cli/utils.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
"""Utility functions for the CAREamics CLI."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
|
|
6
4
|
def handle_2D_3D_callback(
|
|
7
|
-
value:
|
|
8
|
-
) ->
|
|
5
|
+
value: tuple[int, int, int] | None,
|
|
6
|
+
) -> tuple[int, ...] | None:
|
|
9
7
|
"""
|
|
10
8
|
Callback for options that require 2D or 3D inputs.
|
|
11
9
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal
|
|
6
|
+
from typing import Literal
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel, ConfigDict, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
@@ -45,9 +45,9 @@ class VAEBasedAlgorithm(BaseModel):
|
|
|
45
45
|
# NOTE: these are all configs (pydantic models)
|
|
46
46
|
loss: LVAELossConfig
|
|
47
47
|
model: LVAEModel
|
|
48
|
-
noise_model:
|
|
49
|
-
noise_model_likelihood:
|
|
50
|
-
gaussian_likelihood:
|
|
48
|
+
noise_model: MultiChannelNMConfig | None = None
|
|
49
|
+
noise_model_likelihood: NMLikelihoodConfig | None = None
|
|
50
|
+
gaussian_likelihood: GaussianLikelihoodConfig | None = None
|
|
51
51
|
|
|
52
52
|
# Optional fields
|
|
53
53
|
optimizer: OptimizerModel = OptimizerModel()
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from datetime import timedelta
|
|
6
|
-
from typing import Literal
|
|
6
|
+
from typing import Literal
|
|
7
7
|
|
|
8
8
|
from pydantic import (
|
|
9
9
|
BaseModel,
|
|
@@ -33,7 +33,7 @@ class CheckpointModel(BaseModel):
|
|
|
33
33
|
save_weights_only: bool = Field(default=False)
|
|
34
34
|
"""When `True`, only the model's weights will be saved (model.save_weights)."""
|
|
35
35
|
|
|
36
|
-
save_last:
|
|
36
|
+
save_last: Literal[True, False, "link"] | None = Field(default=True)
|
|
37
37
|
"""When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
|
|
38
38
|
|
|
39
39
|
save_top_k: int = Field(default=3, ge=-1, le=100)
|
|
@@ -51,13 +51,13 @@ class CheckpointModel(BaseModel):
|
|
|
51
51
|
auto_insert_metric_name: bool = Field(default=False)
|
|
52
52
|
"""When `True`, the checkpoints filenames will contain the metric name."""
|
|
53
53
|
|
|
54
|
-
every_n_train_steps:
|
|
54
|
+
every_n_train_steps: int | None = Field(default=None, ge=1, le=1000)
|
|
55
55
|
"""Number of training steps between checkpoints."""
|
|
56
56
|
|
|
57
|
-
train_time_interval:
|
|
57
|
+
train_time_interval: timedelta | None = Field(default=None)
|
|
58
58
|
"""Checkpoints are monitored at the specified time interval."""
|
|
59
59
|
|
|
60
|
-
every_n_epochs:
|
|
60
|
+
every_n_epochs: int | None = Field(default=None, ge=1, le=100)
|
|
61
61
|
"""Number of epochs between checkpoints."""
|
|
62
62
|
|
|
63
63
|
|
|
@@ -96,14 +96,14 @@ class EarlyStoppingModel(BaseModel):
|
|
|
96
96
|
"""When `True`, stops training when the monitored quantity becomes `NaN` or
|
|
97
97
|
`inf`."""
|
|
98
98
|
|
|
99
|
-
stopping_threshold:
|
|
99
|
+
stopping_threshold: float | None = Field(default=None)
|
|
100
100
|
"""Stop training immediately once the monitored quantity reaches this threshold."""
|
|
101
101
|
|
|
102
|
-
divergence_threshold:
|
|
102
|
+
divergence_threshold: float | None = Field(default=None)
|
|
103
103
|
"""Stop training as soon as the monitored quantity becomes worse than this
|
|
104
104
|
threshold."""
|
|
105
105
|
|
|
106
|
-
check_on_train_epoch_end:
|
|
106
|
+
check_on_train_epoch_end: bool | None = Field(default=False)
|
|
107
107
|
"""Whether to run early stopping at the end of the training epoch. If this is
|
|
108
108
|
`False`, then the check runs at the end of the validation."""
|
|
109
109
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
|
-
from typing import Annotated, Any, Literal,
|
|
4
|
+
from typing import Annotated, Any, Literal, Union
|
|
5
5
|
|
|
6
6
|
from pydantic import Field, TypeAdapter
|
|
7
7
|
|
|
@@ -50,7 +50,7 @@ def algorithm_factory(
|
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
def _list_spatial_augmentations(
|
|
53
|
-
augmentations:
|
|
53
|
+
augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
|
|
54
54
|
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
55
55
|
"""
|
|
56
56
|
List the augmentations to apply.
|
|
@@ -105,7 +105,7 @@ def _create_unet_configuration(
|
|
|
105
105
|
n_channels_out: int,
|
|
106
106
|
independent_channels: bool,
|
|
107
107
|
use_n2v2: bool,
|
|
108
|
-
model_params:
|
|
108
|
+
model_params: dict[str, Any] | None = None,
|
|
109
109
|
) -> UNetModel:
|
|
110
110
|
"""
|
|
111
111
|
Create a dictionary with the parameters of the UNet model.
|
|
@@ -153,11 +153,11 @@ def _create_algorithm_configuration(
|
|
|
153
153
|
n_channels_in: int,
|
|
154
154
|
n_channels_out: int,
|
|
155
155
|
use_n2v2: bool = False,
|
|
156
|
-
model_params:
|
|
156
|
+
model_params: dict | None = None,
|
|
157
157
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
158
|
-
optimizer_params:
|
|
158
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
159
159
|
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
160
|
-
lr_scheduler_params:
|
|
160
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
161
161
|
) -> dict:
|
|
162
162
|
"""
|
|
163
163
|
Create a dictionary with the parameters of the algorithm model.
|
|
@@ -227,8 +227,8 @@ def _create_data_configuration(
|
|
|
227
227
|
patch_size: list[int],
|
|
228
228
|
batch_size: int,
|
|
229
229
|
augmentations: Union[list[SPATIAL_TRANSFORMS_UNION]],
|
|
230
|
-
train_dataloader_params:
|
|
231
|
-
val_dataloader_params:
|
|
230
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
231
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
232
232
|
) -> DataConfig:
|
|
233
233
|
"""
|
|
234
234
|
Create a dictionary with the parameters of the data model.
|
|
@@ -283,11 +283,11 @@ def _create_ng_data_configuration(
|
|
|
283
283
|
patch_size: Sequence[int],
|
|
284
284
|
batch_size: int,
|
|
285
285
|
augmentations: list[SPATIAL_TRANSFORMS_UNION],
|
|
286
|
-
patch_overlaps:
|
|
287
|
-
train_dataloader_params:
|
|
288
|
-
val_dataloader_params:
|
|
289
|
-
test_dataloader_params:
|
|
290
|
-
seed:
|
|
286
|
+
patch_overlaps: Sequence[int] | None = None,
|
|
287
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
288
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
289
|
+
test_dataloader_params: dict[str, Any] | None = None,
|
|
290
|
+
seed: int | None = None,
|
|
291
291
|
) -> NGDataConfig:
|
|
292
292
|
"""
|
|
293
293
|
Create a dictionary with the parameters of the data model.
|
|
@@ -359,7 +359,7 @@ def _create_ng_data_configuration(
|
|
|
359
359
|
def _create_training_configuration(
|
|
360
360
|
num_epochs: int,
|
|
361
361
|
logger: Literal["wandb", "tensorboard", "none"],
|
|
362
|
-
checkpoint_params:
|
|
362
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
363
363
|
) -> TrainingConfig:
|
|
364
364
|
"""
|
|
365
365
|
Create a dictionary with the parameters of the training model.
|
|
@@ -395,20 +395,20 @@ def _create_supervised_config_dict(
|
|
|
395
395
|
patch_size: list[int],
|
|
396
396
|
batch_size: int,
|
|
397
397
|
num_epochs: int,
|
|
398
|
-
augmentations:
|
|
398
|
+
augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
|
|
399
399
|
independent_channels: bool = True,
|
|
400
400
|
loss: Literal["mae", "mse"] = "mae",
|
|
401
|
-
n_channels_in:
|
|
402
|
-
n_channels_out:
|
|
401
|
+
n_channels_in: int | None = None,
|
|
402
|
+
n_channels_out: int | None = None,
|
|
403
403
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
404
|
-
model_params:
|
|
404
|
+
model_params: dict | None = None,
|
|
405
405
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
406
|
-
optimizer_params:
|
|
406
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
407
407
|
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
408
|
-
lr_scheduler_params:
|
|
409
|
-
train_dataloader_params:
|
|
410
|
-
val_dataloader_params:
|
|
411
|
-
checkpoint_params:
|
|
408
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
409
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
410
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
411
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
412
412
|
) -> dict:
|
|
413
413
|
"""
|
|
414
414
|
Create a configuration for training CARE or Noise2Noise.
|
|
@@ -540,20 +540,20 @@ def create_care_configuration(
|
|
|
540
540
|
patch_size: list[int],
|
|
541
541
|
batch_size: int,
|
|
542
542
|
num_epochs: int,
|
|
543
|
-
augmentations:
|
|
543
|
+
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
544
544
|
independent_channels: bool = True,
|
|
545
545
|
loss: Literal["mae", "mse"] = "mae",
|
|
546
|
-
n_channels_in:
|
|
547
|
-
n_channels_out:
|
|
546
|
+
n_channels_in: int | None = None,
|
|
547
|
+
n_channels_out: int | None = None,
|
|
548
548
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
549
|
-
model_params:
|
|
549
|
+
model_params: dict | None = None,
|
|
550
550
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
551
|
-
optimizer_params:
|
|
551
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
552
552
|
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
553
|
-
lr_scheduler_params:
|
|
554
|
-
train_dataloader_params:
|
|
555
|
-
val_dataloader_params:
|
|
556
|
-
checkpoint_params:
|
|
553
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
554
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
555
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
556
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
557
557
|
) -> Configuration:
|
|
558
558
|
"""
|
|
559
559
|
Create a configuration for training CARE.
|
|
@@ -756,20 +756,20 @@ def create_n2n_configuration(
|
|
|
756
756
|
patch_size: list[int],
|
|
757
757
|
batch_size: int,
|
|
758
758
|
num_epochs: int,
|
|
759
|
-
augmentations:
|
|
759
|
+
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
760
760
|
independent_channels: bool = True,
|
|
761
761
|
loss: Literal["mae", "mse"] = "mae",
|
|
762
|
-
n_channels_in:
|
|
763
|
-
n_channels_out:
|
|
762
|
+
n_channels_in: int | None = None,
|
|
763
|
+
n_channels_out: int | None = None,
|
|
764
764
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
765
|
-
model_params:
|
|
765
|
+
model_params: dict | None = None,
|
|
766
766
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
767
|
-
optimizer_params:
|
|
767
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
768
768
|
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
769
|
-
lr_scheduler_params:
|
|
770
|
-
train_dataloader_params:
|
|
771
|
-
val_dataloader_params:
|
|
772
|
-
checkpoint_params:
|
|
769
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
770
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
771
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
772
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
773
773
|
) -> Configuration:
|
|
774
774
|
"""
|
|
775
775
|
Create a configuration for training Noise2Noise.
|
|
@@ -972,23 +972,23 @@ def create_n2v_configuration(
|
|
|
972
972
|
patch_size: list[int],
|
|
973
973
|
batch_size: int,
|
|
974
974
|
num_epochs: int,
|
|
975
|
-
augmentations:
|
|
975
|
+
augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]] | None = None,
|
|
976
976
|
independent_channels: bool = True,
|
|
977
977
|
use_n2v2: bool = False,
|
|
978
|
-
n_channels:
|
|
978
|
+
n_channels: int | None = None,
|
|
979
979
|
roi_size: int = 11,
|
|
980
980
|
masked_pixel_percentage: float = 0.2,
|
|
981
981
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
982
982
|
struct_n2v_span: int = 5,
|
|
983
983
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
984
|
-
model_params:
|
|
984
|
+
model_params: dict | None = None,
|
|
985
985
|
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
986
|
-
optimizer_params:
|
|
986
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
987
987
|
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
988
|
-
lr_scheduler_params:
|
|
989
|
-
train_dataloader_params:
|
|
990
|
-
val_dataloader_params:
|
|
991
|
-
checkpoint_params:
|
|
988
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
989
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
990
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
991
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
992
992
|
) -> Configuration:
|
|
993
993
|
"""
|
|
994
994
|
Create a configuration for training Noise2Void.
|
|
@@ -6,7 +6,7 @@ import os
|
|
|
6
6
|
import sys
|
|
7
7
|
from collections.abc import Sequence
|
|
8
8
|
from pprint import pformat
|
|
9
|
-
from typing import Annotated, Any, Literal,
|
|
9
|
+
from typing import Annotated, Any, Literal, Union
|
|
10
10
|
from warnings import warn
|
|
11
11
|
|
|
12
12
|
import numpy as np
|
|
@@ -109,22 +109,16 @@ class DataConfig(BaseModel):
|
|
|
109
109
|
"""Batch size for training."""
|
|
110
110
|
|
|
111
111
|
# Optional fields
|
|
112
|
-
image_means:
|
|
113
|
-
default=None, min_length=0, max_length=32
|
|
114
|
-
)
|
|
112
|
+
image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
115
113
|
"""Means of the data across channels, used for normalization."""
|
|
116
114
|
|
|
117
|
-
image_stds:
|
|
115
|
+
image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
118
116
|
"""Standard deviations of the data across channels, used for normalization."""
|
|
119
117
|
|
|
120
|
-
target_means:
|
|
121
|
-
default=None, min_length=0, max_length=32
|
|
122
|
-
)
|
|
118
|
+
target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
123
119
|
"""Means of the target data across channels, used for normalization."""
|
|
124
120
|
|
|
125
|
-
target_stds:
|
|
126
|
-
default=None, min_length=0, max_length=32
|
|
127
|
-
)
|
|
121
|
+
target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
128
122
|
"""Standard deviations of the target data across channels, used for
|
|
129
123
|
normalization."""
|
|
130
124
|
|
|
@@ -388,8 +382,8 @@ class DataConfig(BaseModel):
|
|
|
388
382
|
self,
|
|
389
383
|
image_means: Union[NDArray, tuple, list, None],
|
|
390
384
|
image_stds: Union[NDArray, tuple, list, None],
|
|
391
|
-
target_means:
|
|
392
|
-
target_stds:
|
|
385
|
+
target_means: Union[NDArray, tuple, list, None] | None = None,
|
|
386
|
+
target_stds: Union[NDArray, tuple, list, None] | None = None,
|
|
393
387
|
) -> None:
|
|
394
388
|
"""
|
|
395
389
|
Set mean and standard deviation of the data across channels.
|