careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic; see the package's registry page for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/config/__init__.py
CHANGED
|
@@ -14,9 +14,7 @@ __all__ = [
|
|
|
14
14
|
"create_care_configuration",
|
|
15
15
|
"register_model",
|
|
16
16
|
"CustomModel",
|
|
17
|
-
"create_inference_configuration",
|
|
18
17
|
"clear_custom_models",
|
|
19
|
-
"ConfigurationInformation",
|
|
20
18
|
]
|
|
21
19
|
|
|
22
20
|
from .algorithm_model import AlgorithmConfig
|
|
@@ -24,7 +22,6 @@ from .architectures import CustomModel, clear_custom_models, register_model
|
|
|
24
22
|
from .callback_model import CheckpointModel
|
|
25
23
|
from .configuration_factory import (
|
|
26
24
|
create_care_configuration,
|
|
27
|
-
create_inference_configuration,
|
|
28
25
|
create_n2n_configuration,
|
|
29
26
|
create_n2v_configuration,
|
|
30
27
|
)
|
|
@@ -93,12 +93,20 @@ class AlgorithmConfig(BaseModel):
|
|
|
93
93
|
|
|
94
94
|
# Mandatory fields
|
|
95
95
|
algorithm: Literal["n2v", "care", "n2n", "custom"] # defined in SupportedAlgorithm
|
|
96
|
+
"""Name of the algorithm, as defined in SupportedAlgorithm."""
|
|
97
|
+
|
|
96
98
|
loss: Literal["n2v", "mae", "mse"]
|
|
99
|
+
"""Loss function to use, as defined in SupportedLoss."""
|
|
100
|
+
|
|
97
101
|
model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
|
|
102
|
+
"""Model architecture to use, defined in SupportedArchitecture."""
|
|
98
103
|
|
|
99
104
|
# Optional fields
|
|
100
105
|
optimizer: OptimizerModel = OptimizerModel()
|
|
106
|
+
"""Optimizer to use, defined in SupportedOptimizer."""
|
|
107
|
+
|
|
101
108
|
lr_scheduler: LrSchedulerModel = LrSchedulerModel()
|
|
109
|
+
"""Learning rate scheduler to use, defined in SupportedScheduler."""
|
|
102
110
|
|
|
103
111
|
@model_validator(mode="after")
|
|
104
112
|
def algorithm_cross_validation(self: Self) -> Self:
|
|
@@ -134,21 +142,6 @@ class AlgorithmConfig(BaseModel):
|
|
|
134
142
|
"sure that `in_channels` and `num_classes` are the same."
|
|
135
143
|
)
|
|
136
144
|
|
|
137
|
-
# N2N
|
|
138
|
-
if self.algorithm == "n2n":
|
|
139
|
-
# n2n is only compatible with the UNet model
|
|
140
|
-
if not isinstance(self.model, UNetModel):
|
|
141
|
-
raise ValueError(
|
|
142
|
-
f"Model for algorithm {self.algorithm} must be a `UNetModel`."
|
|
143
|
-
)
|
|
144
|
-
|
|
145
|
-
# n2n requires the number of input and output channels to be the same
|
|
146
|
-
if self.model.in_channels != self.model.num_classes:
|
|
147
|
-
raise ValueError(
|
|
148
|
-
"N2N requires the same number of input and output channels. Make "
|
|
149
|
-
"sure that `in_channels` and `num_classes` are the same."
|
|
150
|
-
)
|
|
151
|
-
|
|
152
145
|
if self.algorithm == "care" or self.algorithm == "n2n":
|
|
153
146
|
if self.loss == "n2v":
|
|
154
147
|
raise ValueError("Supervised algorithms do not support loss `n2v`.")
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Literal
|
|
7
7
|
|
|
8
8
|
from pydantic import ConfigDict, field_validator, model_validator
|
|
9
9
|
from torch.nn import Module
|
|
@@ -72,9 +72,11 @@ class CustomModel(ArchitectureModel):
|
|
|
72
72
|
|
|
73
73
|
# discriminator used for choosing the pydantic model in Model
|
|
74
74
|
architecture: Literal["Custom"]
|
|
75
|
+
"""Name of the architecture."""
|
|
75
76
|
|
|
76
77
|
# name of the custom model
|
|
77
78
|
name: str
|
|
79
|
+
"""Name of the custom model."""
|
|
78
80
|
|
|
79
81
|
@field_validator("name")
|
|
80
82
|
@classmethod
|
|
@@ -136,7 +138,7 @@ class CustomModel(ArchitectureModel):
|
|
|
136
138
|
"""
|
|
137
139
|
return pformat(self.model_dump())
|
|
138
140
|
|
|
139
|
-
def model_dump(self, **kwargs: Any) ->
|
|
141
|
+
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
140
142
|
"""Dump the model configuration.
|
|
141
143
|
|
|
142
144
|
Parameters
|
|
@@ -146,7 +148,7 @@ class CustomModel(ArchitectureModel):
|
|
|
146
148
|
|
|
147
149
|
Returns
|
|
148
150
|
-------
|
|
149
|
-
|
|
151
|
+
dict[str, Any]
|
|
150
152
|
Model configuration.
|
|
151
153
|
"""
|
|
152
154
|
model_dict = super().model_dump()
|
|
@@ -29,19 +29,38 @@ class UNetModel(ArchitectureModel):
|
|
|
29
29
|
|
|
30
30
|
# discriminator used for choosing the pydantic model in Model
|
|
31
31
|
architecture: Literal["UNet"]
|
|
32
|
+
"""Name of the architecture."""
|
|
32
33
|
|
|
33
34
|
# parameters
|
|
34
35
|
# validate_defaults allow ignoring default values in the dump if they were not set
|
|
35
36
|
conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
|
|
37
|
+
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
38
|
+
|
|
36
39
|
num_classes: int = Field(default=1, ge=1, validate_default=True)
|
|
40
|
+
"""Number of classes or channels in the model output."""
|
|
41
|
+
|
|
37
42
|
in_channels: int = Field(default=1, ge=1, validate_default=True)
|
|
43
|
+
"""Number of channels in the input to the model."""
|
|
44
|
+
|
|
38
45
|
depth: int = Field(default=2, ge=1, le=10, validate_default=True)
|
|
46
|
+
"""Number of levels in the UNet."""
|
|
47
|
+
|
|
39
48
|
num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
|
|
49
|
+
"""Number of convolutional filters in the first layer of the UNet."""
|
|
50
|
+
|
|
40
51
|
final_activation: Literal[
|
|
41
52
|
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
|
|
42
53
|
] = Field(default="None", validate_default=True)
|
|
54
|
+
"""Final activation function."""
|
|
55
|
+
|
|
43
56
|
n2v2: bool = Field(default=False, validate_default=True)
|
|
57
|
+
"""Whether to use N2V2 architecture modifications, with blur pool layers and fewer
|
|
58
|
+
skip connections.
|
|
59
|
+
"""
|
|
60
|
+
|
|
44
61
|
independent_channels: bool = Field(default=True, validate_default=True)
|
|
62
|
+
"""Whether information is processed independently in each channel, used to train
|
|
63
|
+
channels independently."""
|
|
45
64
|
|
|
46
65
|
@field_validator("num_channels_init")
|
|
47
66
|
@classmethod
|
|
@@ -13,69 +13,111 @@ from pydantic import (
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class CheckpointModel(BaseModel):
|
|
16
|
-
"""Checkpoint saving callback Pydantic model.
|
|
16
|
+
"""Checkpoint saving callback Pydantic model.
|
|
17
|
+
|
|
18
|
+
The parameters correspond to those of
|
|
19
|
+
`pytorch_lightning.callbacks.ModelCheckpoint`.
|
|
20
|
+
|
|
21
|
+
See:
|
|
22
|
+
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
|
|
23
|
+
"""
|
|
17
24
|
|
|
18
25
|
model_config = ConfigDict(
|
|
19
26
|
validate_assignment=True,
|
|
20
27
|
)
|
|
21
28
|
|
|
22
29
|
monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
|
|
30
|
+
"""Quantity to monitor."""
|
|
31
|
+
|
|
23
32
|
verbose: bool = Field(default=False, validate_default=True)
|
|
33
|
+
"""Verbosity mode."""
|
|
34
|
+
|
|
24
35
|
save_weights_only: bool = Field(default=False, validate_default=True)
|
|
36
|
+
"""When `True`, only the model's weights will be saved (model.save_weights)."""
|
|
37
|
+
|
|
38
|
+
save_last: Optional[Literal[True, False, "link"]] = Field(
|
|
39
|
+
default=True, validate_default=True
|
|
40
|
+
)
|
|
41
|
+
"""When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
|
|
42
|
+
|
|
43
|
+
save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
|
|
44
|
+
"""If `save_top_k == k`, the best k models according to the quantity monitored
|
|
45
|
+
will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
|
|
46
|
+
all models are saved."""
|
|
47
|
+
|
|
25
48
|
mode: Literal["min", "max"] = Field(default="min", validate_default=True)
|
|
49
|
+
"""One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
|
|
50
|
+
save file is made based on either the maximization or the minimization of the
|
|
51
|
+
monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
|
|
52
|
+
be 'min', etc.
|
|
53
|
+
"""
|
|
54
|
+
|
|
26
55
|
auto_insert_metric_name: bool = Field(default=False, validate_default=True)
|
|
56
|
+
"""When `True`, the checkpoints filenames will contain the metric name."""
|
|
57
|
+
|
|
27
58
|
every_n_train_steps: Optional[int] = Field(
|
|
28
59
|
default=None, ge=1, le=10, validate_default=True
|
|
29
60
|
)
|
|
61
|
+
"""Number of training steps between checkpoints."""
|
|
62
|
+
|
|
30
63
|
train_time_interval: Optional[timedelta] = Field(
|
|
31
64
|
default=None, validate_default=True
|
|
32
65
|
)
|
|
66
|
+
"""Checkpoints are monitored at the specified time interval."""
|
|
67
|
+
|
|
33
68
|
every_n_epochs: Optional[int] = Field(
|
|
34
69
|
default=None, ge=1, le=10, validate_default=True
|
|
35
70
|
)
|
|
36
|
-
|
|
37
|
-
default=True, validate_default=True
|
|
38
|
-
)
|
|
39
|
-
save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
|
|
71
|
+
"""Number of epochs between checkpoints."""
|
|
40
72
|
|
|
41
73
|
|
|
42
74
|
class EarlyStoppingModel(BaseModel):
|
|
43
|
-
"""Early stopping callback Pydantic model.
|
|
75
|
+
"""Early stopping callback Pydantic model.
|
|
76
|
+
|
|
77
|
+
The parameters correspond to those of
|
|
78
|
+
`pytorch_lightning.callbacks.EarlyStopping`.
|
|
79
|
+
|
|
80
|
+
See:
|
|
81
|
+
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
|
|
82
|
+
"""
|
|
44
83
|
|
|
45
84
|
model_config = ConfigDict(
|
|
46
85
|
validate_assignment=True,
|
|
47
86
|
)
|
|
48
87
|
|
|
49
88
|
monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
|
|
89
|
+
"""Quantity to monitor."""
|
|
90
|
+
|
|
91
|
+
min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
|
|
92
|
+
"""Minimum change in the monitored quantity to qualify as an improvement, i.e. an
|
|
93
|
+
absolute change of less than or equal to min_delta, will count as no improvement."""
|
|
94
|
+
|
|
50
95
|
patience: int = Field(default=3, ge=1, le=10, validate_default=True)
|
|
96
|
+
"""Number of checks with no improvement after which training will be stopped."""
|
|
97
|
+
|
|
98
|
+
verbose: bool = Field(default=False, validate_default=True)
|
|
99
|
+
"""Verbosity mode."""
|
|
100
|
+
|
|
51
101
|
mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
|
|
52
|
-
|
|
102
|
+
"""One of {min, max, auto}."""
|
|
103
|
+
|
|
53
104
|
check_finite: bool = Field(default=True, validate_default=True)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
default=10.0, ge=0.0, le=1e6, validate_default=True
|
|
67
|
-
)
|
|
68
|
-
auto_lr_find_min_lr: float = Field(
|
|
69
|
-
default=1e-8, ge=0.0, le=1e6, validate_default=True
|
|
70
|
-
)
|
|
71
|
-
auto_lr_find_num_training: int = Field(
|
|
72
|
-
default=100, ge=1, le=1e6, validate_default=True
|
|
73
|
-
)
|
|
74
|
-
auto_lr_find_divergence_threshold: float = Field(
|
|
75
|
-
default=5.0, ge=0.0, le=1e6, validate_default=True
|
|
76
|
-
)
|
|
77
|
-
auto_lr_find_accumulate_grad_batches: int = Field(
|
|
78
|
-
default=1, ge=1, le=1e6, validate_default=True
|
|
105
|
+
"""When `True`, stops training when the monitored quantity becomes `NaN` or
|
|
106
|
+
`inf`."""
|
|
107
|
+
|
|
108
|
+
stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
|
|
109
|
+
"""Stop training immediately once the monitored quantity reaches this threshold."""
|
|
110
|
+
|
|
111
|
+
divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
|
|
112
|
+
"""Stop training as soon as the monitored quantity becomes worse than this
|
|
113
|
+
threshold."""
|
|
114
|
+
|
|
115
|
+
check_on_train_epoch_end: Optional[bool] = Field(
|
|
116
|
+
default=False, validate_default=True
|
|
79
117
|
)
|
|
80
|
-
|
|
81
|
-
|
|
118
|
+
"""Whether to run early stopping at the end of the training epoch. If this is
|
|
119
|
+
`False`, then the check runs at the end of the validation."""
|
|
120
|
+
|
|
121
|
+
log_rank_zero_only: bool = Field(default=False, validate_default=True)
|
|
122
|
+
"""When set `True`, logs the status of the early stopping callback only for rank 0
|
|
123
|
+
process."""
|
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Dict, List, Literal, Optional
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
4
4
|
|
|
5
5
|
from .algorithm_model import AlgorithmConfig
|
|
6
6
|
from .architectures import UNetModel
|
|
7
7
|
from .configuration_model import Configuration
|
|
8
8
|
from .data_model import DataConfig
|
|
9
|
-
from .inference_model import InferenceConfig
|
|
10
9
|
from .support import (
|
|
11
10
|
SupportedAlgorithm,
|
|
12
11
|
SupportedArchitecture,
|
|
@@ -107,9 +106,6 @@ def _create_supervised_configuration(
|
|
|
107
106
|
# augmentations
|
|
108
107
|
if use_augmentations:
|
|
109
108
|
transforms: List[Dict[str, Any]] = [
|
|
110
|
-
{
|
|
111
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
112
|
-
},
|
|
113
109
|
{
|
|
114
110
|
"name": SupportedTransform.XY_FLIP.value,
|
|
115
111
|
},
|
|
@@ -118,11 +114,7 @@ def _create_supervised_configuration(
|
|
|
118
114
|
},
|
|
119
115
|
]
|
|
120
116
|
else:
|
|
121
|
-
transforms = [
|
|
122
|
-
{
|
|
123
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
124
|
-
},
|
|
125
|
-
]
|
|
117
|
+
transforms = []
|
|
126
118
|
|
|
127
119
|
# data model
|
|
128
120
|
data = DataConfig(
|
|
@@ -250,7 +242,8 @@ def create_n2n_configuration(
|
|
|
250
242
|
use_augmentations: bool = True,
|
|
251
243
|
independent_channels: bool = False,
|
|
252
244
|
loss: Literal["mae", "mse"] = "mae",
|
|
253
|
-
|
|
245
|
+
n_channels_in: int = 1,
|
|
246
|
+
n_channels_out: int = -1,
|
|
254
247
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
255
248
|
model_kwargs: Optional[dict] = None,
|
|
256
249
|
) -> Configuration:
|
|
@@ -260,10 +253,13 @@ def create_n2n_configuration(
|
|
|
260
253
|
If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
|
|
261
254
|
2.
|
|
262
255
|
|
|
263
|
-
If "C" is present in `axes`, then you need to set `
|
|
256
|
+
If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
|
|
264
257
|
channels. Likewise, if you set the number of channels, then "C" must be present in
|
|
265
258
|
`axes`.
|
|
266
259
|
|
|
260
|
+
To set the number of output channels, use the `n_channels_out` parameter. If it is
|
|
261
|
+
not specified, it will be assumed to be equal to `n_channels_in`.
|
|
262
|
+
|
|
267
263
|
By default, all channels are trained together. To train all channels independently,
|
|
268
264
|
set `independent_channels` to True.
|
|
269
265
|
|
|
@@ -290,8 +286,10 @@ def create_n2n_configuration(
|
|
|
290
286
|
Whether to train all channels independently, by default False.
|
|
291
287
|
loss : Literal["mae", "mse"], optional
|
|
292
288
|
Loss function to use, by default "mae".
|
|
293
|
-
|
|
294
|
-
Number of channels
|
|
289
|
+
n_channels_in : int, optional
|
|
290
|
+
Number of input channels, by default 1.
|
|
291
|
+
n_channels_out : int, optional
|
|
292
|
+
Number of output channels, by default -1 (meaning the same as `n_channels_in`).
|
|
295
293
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
296
294
|
Logger to use, by default "none".
|
|
297
295
|
model_kwargs : dict, optional
|
|
@@ -302,6 +300,9 @@ def create_n2n_configuration(
|
|
|
302
300
|
Configuration
|
|
303
301
|
Configuration for training Noise2Noise.
|
|
304
302
|
"""
|
|
303
|
+
if n_channels_out == -1:
|
|
304
|
+
n_channels_out = n_channels_in
|
|
305
|
+
|
|
305
306
|
return _create_supervised_configuration(
|
|
306
307
|
algorithm="n2n",
|
|
307
308
|
experiment_name=experiment_name,
|
|
@@ -313,8 +314,8 @@ def create_n2n_configuration(
|
|
|
313
314
|
use_augmentations=use_augmentations,
|
|
314
315
|
independent_channels=independent_channels,
|
|
315
316
|
loss=loss,
|
|
316
|
-
n_channels_in=
|
|
317
|
-
n_channels_out=
|
|
317
|
+
n_channels_in=n_channels_in,
|
|
318
|
+
n_channels_out=n_channels_out,
|
|
318
319
|
logger=logger,
|
|
319
320
|
model_kwargs=model_kwargs,
|
|
320
321
|
)
|
|
@@ -522,9 +523,6 @@ def create_n2v_configuration(
|
|
|
522
523
|
# augmentations
|
|
523
524
|
if use_augmentations:
|
|
524
525
|
transforms: List[Dict[str, Any]] = [
|
|
525
|
-
{
|
|
526
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
527
|
-
},
|
|
528
526
|
{
|
|
529
527
|
"name": SupportedTransform.XY_FLIP.value,
|
|
530
528
|
},
|
|
@@ -533,11 +531,7 @@ def create_n2v_configuration(
|
|
|
533
531
|
},
|
|
534
532
|
]
|
|
535
533
|
else:
|
|
536
|
-
transforms = [
|
|
537
|
-
{
|
|
538
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
539
|
-
},
|
|
540
|
-
]
|
|
534
|
+
transforms = []
|
|
541
535
|
|
|
542
536
|
# n2v2 and structn2v
|
|
543
537
|
nv2_transform = {
|
|
@@ -579,77 +573,3 @@ def create_n2v_configuration(
|
|
|
579
573
|
)
|
|
580
574
|
|
|
581
575
|
return configuration
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
def create_inference_configuration(
|
|
585
|
-
configuration: Configuration,
|
|
586
|
-
tile_size: Optional[Tuple[int, ...]] = None,
|
|
587
|
-
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
588
|
-
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
589
|
-
axes: Optional[str] = None,
|
|
590
|
-
tta_transforms: bool = True,
|
|
591
|
-
batch_size: Optional[int] = 1,
|
|
592
|
-
) -> InferenceConfig:
|
|
593
|
-
"""
|
|
594
|
-
Create a configuration for inference with N2V.
|
|
595
|
-
|
|
596
|
-
If not provided, `data_type` and `axes` are taken from the training
|
|
597
|
-
configuration.
|
|
598
|
-
|
|
599
|
-
Parameters
|
|
600
|
-
----------
|
|
601
|
-
configuration : Configuration
|
|
602
|
-
Global configuration.
|
|
603
|
-
tile_size : Tuple[int, ...], optional
|
|
604
|
-
Size of the tiles.
|
|
605
|
-
tile_overlap : Tuple[int, ...], optional
|
|
606
|
-
Overlap of the tiles.
|
|
607
|
-
data_type : str, optional
|
|
608
|
-
Type of the data, by default "tiff".
|
|
609
|
-
axes : str, optional
|
|
610
|
-
Axes of the data, by default "YX".
|
|
611
|
-
tta_transforms : bool, optional
|
|
612
|
-
Whether to apply test-time augmentations, by default True.
|
|
613
|
-
batch_size : int, optional
|
|
614
|
-
Batch size, by default 1.
|
|
615
|
-
|
|
616
|
-
Returns
|
|
617
|
-
-------
|
|
618
|
-
InferenceConfiguration
|
|
619
|
-
Configuration used to configure CAREamicsPredictData.
|
|
620
|
-
"""
|
|
621
|
-
if configuration.data_config.mean is None or configuration.data_config.std is None:
|
|
622
|
-
raise ValueError("Mean and std must be provided in the configuration.")
|
|
623
|
-
|
|
624
|
-
# tile size for UNets
|
|
625
|
-
if tile_size is not None:
|
|
626
|
-
model = configuration.algorithm_config.model
|
|
627
|
-
|
|
628
|
-
if model.architecture == SupportedArchitecture.UNET.value:
|
|
629
|
-
# tile size must be equal to k*2^n, where n is the number of pooling layers
|
|
630
|
-
# (equal to the depth) and k is an integer
|
|
631
|
-
depth = model.depth
|
|
632
|
-
tile_increment = 2**depth
|
|
633
|
-
|
|
634
|
-
for i, t in enumerate(tile_size):
|
|
635
|
-
if t % tile_increment != 0:
|
|
636
|
-
raise ValueError(
|
|
637
|
-
f"Tile size must be divisible by {tile_increment} along all "
|
|
638
|
-
f"axes (got {t} for axis {i}). If your image size is smaller "
|
|
639
|
-
f"along one axis (e.g. Z), consider padding the image."
|
|
640
|
-
)
|
|
641
|
-
|
|
642
|
-
# tile overlaps must be specified
|
|
643
|
-
if tile_overlap is None:
|
|
644
|
-
raise ValueError("Tile overlap must be specified.")
|
|
645
|
-
|
|
646
|
-
return InferenceConfig(
|
|
647
|
-
data_type=data_type or configuration.data_config.data_type,
|
|
648
|
-
tile_size=tile_size,
|
|
649
|
-
tile_overlap=tile_overlap,
|
|
650
|
-
axes=axes or configuration.data_config.axes,
|
|
651
|
-
mean=configuration.data_config.mean,
|
|
652
|
-
std=configuration.data_config.std,
|
|
653
|
-
tta_transforms=tta_transforms,
|
|
654
|
-
batch_size=batch_size,
|
|
655
|
-
)
|
|
@@ -5,11 +5,11 @@ from __future__ import annotations
|
|
|
5
5
|
import re
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from pprint import pformat
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Literal, Union
|
|
9
9
|
|
|
10
10
|
import yaml
|
|
11
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
12
|
-
from pydantic import BaseModel, ConfigDict,
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
|
13
13
|
from typing_extensions import Self
|
|
14
14
|
|
|
15
15
|
from .algorithm_model import AlgorithmConfig
|
|
@@ -147,20 +147,25 @@ class Configuration(BaseModel):
|
|
|
147
147
|
)
|
|
148
148
|
|
|
149
149
|
# version
|
|
150
|
-
version: Literal["0.1.0"] =
|
|
151
|
-
|
|
152
|
-
)
|
|
150
|
+
version: Literal["0.1.0"] = "0.1.0"
|
|
151
|
+
"""CAREamics configuration version."""
|
|
153
152
|
|
|
154
153
|
# required parameters
|
|
155
|
-
experiment_name: str
|
|
156
|
-
|
|
157
|
-
)
|
|
154
|
+
experiment_name: str
|
|
155
|
+
"""Name of the experiment, used to name logs and checkpoints."""
|
|
158
156
|
|
|
159
157
|
# Sub-configurations
|
|
160
158
|
algorithm_config: AlgorithmConfig
|
|
159
|
+
"""Algorithm configuration, holding all parameters required to configure the
|
|
160
|
+
model."""
|
|
161
161
|
|
|
162
162
|
data_config: DataConfig
|
|
163
|
+
"""Data configuration, holding all parameters required to configure the training
|
|
164
|
+
data loader."""
|
|
165
|
+
|
|
163
166
|
training_config: TrainingConfig
|
|
167
|
+
"""Training configuration, holding all parameters required to configure the
|
|
168
|
+
training process."""
|
|
164
169
|
|
|
165
170
|
@field_validator("experiment_name")
|
|
166
171
|
@classmethod
|
|
@@ -269,7 +274,7 @@ class Configuration(BaseModel):
|
|
|
269
274
|
"""
|
|
270
275
|
return pformat(self.model_dump())
|
|
271
276
|
|
|
272
|
-
def set_3D(self, is_3D: bool, axes: str, patch_size:
|
|
277
|
+
def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
|
|
273
278
|
"""
|
|
274
279
|
Set 3D flag and axes.
|
|
275
280
|
|
|
@@ -279,7 +284,7 @@ class Configuration(BaseModel):
|
|
|
279
284
|
Whether the algorithm is 3D or not.
|
|
280
285
|
axes : str
|
|
281
286
|
Axes of the data.
|
|
282
|
-
patch_size :
|
|
287
|
+
patch_size : list[int]
|
|
283
288
|
Patch size.
|
|
284
289
|
"""
|
|
285
290
|
# set the flag and axes (this will not trigger validation at the config level)
|
|
@@ -389,7 +394,7 @@ class Configuration(BaseModel):
|
|
|
389
394
|
|
|
390
395
|
return ""
|
|
391
396
|
|
|
392
|
-
def get_algorithm_citations(self) ->
|
|
397
|
+
def get_algorithm_citations(self) -> list[CiteEntry]:
|
|
393
398
|
"""
|
|
394
399
|
Return a list of citation entries of the current algorithm.
|
|
395
400
|
|
|
@@ -455,13 +460,13 @@ class Configuration(BaseModel):
|
|
|
455
460
|
|
|
456
461
|
return ""
|
|
457
462
|
|
|
458
|
-
def get_algorithm_keywords(self) ->
|
|
463
|
+
def get_algorithm_keywords(self) -> list[str]:
|
|
459
464
|
"""
|
|
460
465
|
Get algorithm keywords.
|
|
461
466
|
|
|
462
467
|
Returns
|
|
463
468
|
-------
|
|
464
|
-
|
|
469
|
+
list[str]
|
|
465
470
|
List of keywords.
|
|
466
471
|
"""
|
|
467
472
|
if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
|
|
@@ -491,8 +496,8 @@ class Configuration(BaseModel):
|
|
|
491
496
|
self,
|
|
492
497
|
exclude_defaults: bool = False,
|
|
493
498
|
exclude_none: bool = True,
|
|
494
|
-
**kwargs:
|
|
495
|
-
) ->
|
|
499
|
+
**kwargs: dict,
|
|
500
|
+
) -> dict:
|
|
496
501
|
"""
|
|
497
502
|
Override model_dump method in order to set default values.
|
|
498
503
|
|
|
@@ -503,7 +508,7 @@ class Configuration(BaseModel):
|
|
|
503
508
|
True.
|
|
504
509
|
exclude_none : bool, optional
|
|
505
510
|
Whether to exclude fields with None values or not, by default True.
|
|
506
|
-
**kwargs :
|
|
511
|
+
**kwargs : dict
|
|
507
512
|
Keyword arguments.
|
|
508
513
|
|
|
509
514
|
Returns
|
|
@@ -524,7 +529,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
|
|
|
524
529
|
|
|
525
530
|
Parameters
|
|
526
531
|
----------
|
|
527
|
-
path :
|
|
532
|
+
path : str or Path
|
|
528
533
|
Path to the configuration.
|
|
529
534
|
|
|
530
535
|
Returns
|
|
@@ -556,7 +561,7 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
|
|
|
556
561
|
----------
|
|
557
562
|
config : Configuration
|
|
558
563
|
Configuration to save.
|
|
559
|
-
path :
|
|
564
|
+
path : str or Path
|
|
560
565
|
Path to an existing folder in which to save the configuration or to an existing
|
|
561
566
|
configuration file.
|
|
562
567
|
|