careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Callback Pydantic models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from typing import Literal, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
ConfigDict,
|
|
11
|
+
Field,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CheckpointModel(BaseModel):
    """Checkpoint saving callback Pydantic model.

    The parameters correspond to those of
    `pytorch_lightning.callbacks.ModelCheckpoint`.

    See:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
    """Quantity to monitor."""

    verbose: bool = Field(default=False, validate_default=True)
    """Verbosity mode."""

    save_weights_only: bool = Field(default=False, validate_default=True)
    """When `True`, only the model's weights will be saved; otherwise the optimizer
    and learning rate scheduler states are saved in the checkpoint as well."""

    save_last: Optional[Literal[True, False, "link"]] = Field(
        default=True, validate_default=True
    )
    """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved.
    When `"link"`, a symbolic link is created instead of a copy."""

    # -1 (save all) and 0 (save none) are valid values in Lightning and are
    # documented below, so the lower bound must allow them.
    save_top_k: int = Field(default=3, ge=-1, le=10, validate_default=True)
    """If `save_top_k == k`, the best k models according to the quantity monitored
    will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
    all models are saved."""

    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
    save file is made based on either the maximization or the minimization of the
    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
    be 'min', etc.
    """

    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
    """When `True`, the checkpoints filenames will contain the metric name."""

    # NOTE(review): an upper bound of 10 training steps between checkpoints is very
    # restrictive; confirm that `le=10` is intended here.
    every_n_train_steps: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
    """Number of training steps between checkpoints."""

    train_time_interval: Optional[timedelta] = Field(
        default=None, validate_default=True
    )
    """Checkpoints are monitored at the specified time interval."""

    # NOTE(review): as above, confirm that `le=10` epochs is the intended bound.
    every_n_epochs: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
    """Number of epochs between checkpoints."""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class EarlyStoppingModel(BaseModel):
    """Early stopping callback Pydantic model.

    The parameters correspond to those of
    `pytorch_lightning.callbacks.EarlyStopping`.

    See:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
    """Quantity to monitor."""

    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
    absolute change of less than or equal to min_delta, will count as no improvement."""

    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
    """Number of checks with no improvement after which training will be stopped."""

    verbose: bool = Field(default=False, validate_default=True)
    """Verbosity mode."""

    # NOTE(review): recent Lightning `EarlyStopping` versions only accept "min" and
    # "max" for `mode`; confirm that "auto" is still supported by the pinned version.
    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
    """One of {min, max, auto}. In 'min' mode, training stops when the monitored
    quantity has stopped decreasing; in 'max' mode, when it has stopped increasing."""

    check_finite: bool = Field(default=True, validate_default=True)
    """When `True`, stops training when the monitored quantity becomes `NaN` or
    `inf`."""

    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
    """Stop training immediately once the monitored quantity reaches this threshold."""

    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
    """Stop training as soon as the monitored quantity becomes worse than this
    threshold."""

    check_on_train_epoch_end: Optional[bool] = Field(
        default=False, validate_default=True
    )
    """Whether to run early stopping at the end of the training epoch. If this is
    `False`, then the check runs at the end of the validation."""

    log_rank_zero_only: bool = Field(default=False, validate_default=True)
    """When set `True`, logs the status of the early stopping callback only for rank 0
    process."""
|
|
@@ -0,0 +1,583 @@
|
|
|
1
|
+
"""Convenience functions to create configurations for training and inference."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
4
|
+
|
|
5
|
+
from .architectures import UNetModel
|
|
6
|
+
from .configuration_model import Configuration
|
|
7
|
+
from .data_model import DataConfig
|
|
8
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
9
|
+
from .support import (
|
|
10
|
+
SupportedAlgorithm,
|
|
11
|
+
SupportedArchitecture,
|
|
12
|
+
SupportedLoss,
|
|
13
|
+
SupportedPixelManipulation,
|
|
14
|
+
SupportedTransform,
|
|
15
|
+
)
|
|
16
|
+
from .training_model import TrainingConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# TODO rename ?
|
|
20
|
+
def _create_supervised_configuration(
    algorithm_type: Literal["fcn"],
    algorithm: Literal["care", "n2n"],
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    independent_channels: bool = False,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int = 1,
    n_channels_out: int = 1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training CARE or Noise2Noise.

    Parameters
    ----------
    algorithm_type : Literal["fcn"]
        Type of the algorithm.
    algorithm : Literal["care", "n2n"]
        Algorithm to use.
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations (XY flip and XY random 90 degree rotation), by
        default True.
    independent_channels : bool, optional
        Whether to train all channels independently, by default False.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels_in : int, optional
        Number of channels in, by default 1.
    n_channels_out : int, optional
        Number of channels out, by default 1.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None. The dictionary is not modified. The
        entries `conv_dims`, `in_channels`, `num_classes` and `independent_channels`
        are derived from the other arguments and override any value given here.

    Returns
    -------
    Configuration
        Configuration for training CARE or Noise2Noise.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels_in` is 1, or if "C" is not in `axes`
        but `n_channels_in` is greater than 1.
    """
    # the presence of a channel axis and the number of input channels must agree
    if "C" in axes and n_channels_in == 1:
        raise ValueError(
            "Number of channels in must be specified when using channels "
            f"(got {n_channels_in} channel)."
        )
    elif "C" not in axes and n_channels_in > 1:
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels_in} channels)."
        )

    # model: work on a copy so that the caller's dictionary is never mutated
    unet_kwargs: Dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
    unet_kwargs["conv_dims"] = 3 if "Z" in axes else 2
    unet_kwargs["in_channels"] = n_channels_in
    unet_kwargs["num_classes"] = n_channels_out
    unet_kwargs["independent_channels"] = independent_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **unet_kwargs,
    )

    # algorithm model
    algorithm_config = FCNAlgorithmConfig(
        algorithm_type=algorithm_type,
        algorithm=algorithm,
        loss=loss,
        model=unet_model,
    )

    # augmentations
    transforms: List[Dict[str, Any]] = []
    if use_augmentations:
        transforms = [
            {
                "name": SupportedTransform.XY_FLIP.value,
            },
            {
                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
            },
        ]

    # data model
    data = DataConfig(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is mapped to no logger)
    training = TrainingConfig(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    configuration = Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data,
        training_config=training,
    )

    return configuration
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def create_care_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    independent_channels: bool = False,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int = 1,
    n_channels_out: int = -1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    independent_channels : bool, optional
        Whether to train all channels independently, by default False.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels_in : int, optional
        Number of channels in, by default 1.
    n_channels_out : int, optional
        Number of channels out, by default -1, meaning equal to `n_channels_in`.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None.

    Returns
    -------
    Configuration
        Configuration for training CARE.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels_in` is 1, or if "C" is not in `axes`
        but `n_channels_in` is greater than 1.
    """
    # -1 is a sentinel: by default, output as many channels as are input
    if n_channels_out == -1:
        n_channels_out = n_channels_in

    return _create_supervised_configuration(
        algorithm_type="fcn",
        algorithm="care",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        use_augmentations=use_augmentations,
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
        model_kwargs=model_kwargs,
    )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def create_n2n_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    independent_channels: bool = False,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int = 1,
    n_channels_out: int = -1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    independent_channels : bool, optional
        Whether to train all channels independently, by default False.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels_in : int, optional
        Number of channels in, by default 1.
    n_channels_out : int, optional
        Number of channels out, by default -1, meaning equal to `n_channels_in`.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None.

    Returns
    -------
    Configuration
        Configuration for training Noise2Noise.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels_in` is 1, or if "C" is not in `axes`
        but `n_channels_in` is greater than 1.
    """
    # -1 is a sentinel: by default, output as many channels as are input
    if n_channels_out == -1:
        n_channels_out = n_channels_in

    return _create_supervised_configuration(
        algorithm_type="fcn",
        algorithm="n2n",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        use_augmentations=use_augmentations,
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
        model_kwargs=model_kwargs,
    )
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def create_n2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    independent_channels: bool = True,
    use_n2v2: bool = False,
    n_channels: int = 1,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_kwargs` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `n_channels`,
    `independent_channels` and the dimensionality derived from `axes` override the
    corresponding parameters (`n2v2`, `in_channels`, `num_classes`,
    `independent_channels`, `conv_dims`) passed in `model_kwargs`. The dictionary
    passed by the caller is not modified.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True.
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None (no extra parameters).

    Returns
    -------
    Configuration
        Configuration for training N2V.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels` is 1, or if "C" is not in `axes` but
        `n_channels` is greater than 1.

    Examples
    --------
    Minimum example:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    To use N2V2, simply pass the `use_n2v2` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v2_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_n2v2=True
    ... )

    For structN2V, there are two parameters to set, `struct_n2v_axis` and
    `struct_n2v_span`:
    >>> config = create_n2v_configuration(
    ...     experiment_name="structn2v_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     struct_n2v_axis="horizontal",
    ...     struct_n2v_span=7
    ... )

    If you are training multiple channels independently, then you need to specify the
    number of channels:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels=3
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels=3
    ... )

    To turn off the augmentations, except normalization and N2V manipulation, use the
    relevant keyword argument:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_augmentations=False
    ... )
    """
    # if there are channels, we need to specify their number
    if "C" in axes and n_channels == 1:
        raise ValueError(
            "Number of channels must be specified when using channels "
            f"(got {n_channels} channel)."
        )
    elif "C" not in axes and n_channels > 1:
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )

    # model: work on a shallow copy so that the overrides below do not leak back
    # into the dictionary owned by the caller (which may be reused across calls)
    model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
    model_kwargs["n2v2"] = use_n2v2
    model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
    model_kwargs["in_channels"] = n_channels
    model_kwargs["num_classes"] = n_channels
    model_kwargs["independent_channels"] = independent_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **model_kwargs,
    )

    # algorithm model
    algorithm = FCNAlgorithmConfig(
        algorithm_type="fcn",
        algorithm=SupportedAlgorithm.N2V.value,
        loss=SupportedLoss.N2V.value,
        model=unet_model,
    )

    # augmentations
    if use_augmentations:
        transforms: List[Dict[str, Any]] = [
            {
                "name": SupportedTransform.XY_FLIP.value,
            },
            {
                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
            },
        ]
    else:
        transforms = []

    # n2v2 and structn2v: N2V2 uses median replacement, plain N2V uses uniform
    # sampling; the struct mask settings are forwarded unchanged
    n2v_transform = {
        "name": SupportedTransform.N2V_MANIPULATE.value,
        "strategy": (
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        "roi_size": roi_size,
        "masked_pixel_percentage": masked_pixel_percentage,
        "struct_mask_axis": struct_n2v_axis,
        "struct_mask_span": struct_n2v_span,
    }
    # the N2V manipulation is always appended last, after any augmentation
    transforms.append(n2v_transform)

    # data model
    data = DataConfig(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is mapped to None, i.e. no logger)
    training = TrainingConfig(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    configuration = Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm,
        data_config=data,
        training_config=training,
    )

    return configuration
|