careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
```diff
@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.config import (
-    Configuration,
-    load_configuration,
-)
+from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -25,7 +22,7 @@ from careamics.config.support import (
 )
 from careamics.dataset.dataset_utils import reshape_array
 from careamics.lightning import (
-    CAREamicsModule,
+    FCNModule,
     HyperParametersCallback,
     PredictDataModule,
     ProgressBarCallback,
```
```diff
@@ -51,8 +48,6 @@ class CAREamist:
     work_dir : str, optional
         Path to working directory in which to save checkpoints and logs,
         by default None.
-    experiment_name : str, by default "CAREamics"
-        Experiment name used for checkpoints.
     callbacks : list of Callback, optional
         List of callbacks to use during training and prediction, by default None.
 
@@ -78,8 +73,7 @@ class CAREamist:
     def __init__(  # numpydoc ignore=GL08
         self,
         source: Union[Path, str],
-        work_dir: Optional[str] = None,
-        experiment_name: str = "CAREamics",
+        work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
@@ -87,8 +81,7 @@ class CAREamist:
     def __init__(  # numpydoc ignore=GL08
         self,
         source: Configuration,
-        work_dir: Optional[str] = None,
-        experiment_name: str = "CAREamics",
+        work_dir: Optional[Union[Path, str]] = None,
         callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
@@ -96,7 +89,6 @@ class CAREamist:
         self,
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
-        experiment_name: str = "CAREamics",
         callbacks: Optional[list[Callback]] = None,
     ) -> None:
         """
@@ -109,18 +101,13 @@ class CAREamist:
 
         If no working directory is provided, the current working directory is used.
 
-        If `source` is a checkpoint, then `experiment_name` is used to name the
-        checkpoint, and is recorded in the configuration.
-
         Parameters
         ----------
         source : pathlib.Path or str or CAREamics Configuration
             Path to a configuration file or a trained model.
-        work_dir : str, optional
+        work_dir : str or pathlib.Path, optional
             Path to working directory in which to save checkpoints and logs,
             by default None.
-        experiment_name : str, optional
-            Experiment name used for checkpoints, by default "CAREamics".
         callbacks : list of Callback, optional
             List of callbacks to use during training and prediction, by default None.
 
```
```diff
@@ -148,9 +135,12 @@ class CAREamist:
             self.cfg = source
 
             # instantiate model
-            self.model = CAREamicsModule(
-                algorithm_config=self.cfg.algorithm_config,
-            )
+            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                self.model = FCNModule(
+                    algorithm_config=self.cfg.algorithm_config,
+                )
+            else:
+                raise NotImplementedError("Architecture not supported.")
 
         # path to configuration file or model
         else:
@@ -164,9 +154,12 @@ class CAREamist:
                 self.cfg = load_configuration(source)
 
                 # instantiate model
-                self.model = CAREamicsModule(
-                    algorithm_config=self.cfg.algorithm_config,
-                )
+                if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                    self.model = FCNModule(
+                        algorithm_config=self.cfg.algorithm_config,
+                    )  # type: ignore
+                else:
+                    raise NotImplementedError("Architecture not supported.")
 
             # attempt loading a pre-trained model
             else:
```
```diff
@@ -192,6 +185,13 @@ class CAREamist:
         # instantiate trainer
         self.trainer = Trainer(
             max_epochs=self.cfg.training_config.num_epochs,
+            precision=self.cfg.training_config.precision,
+            max_steps=self.cfg.training_config.max_steps,
+            check_val_every_n_epoch=self.cfg.training_config.check_val_every_n_epoch,
+            enable_progress_bar=self.cfg.training_config.enable_progress_bar,
+            accumulate_grad_batches=self.cfg.training_config.accumulate_grad_batches,
+            gradient_clip_val=self.cfg.training_config.gradient_clip_val,
+            gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
             callbacks=self.callbacks,
             default_root_dir=self.work_dir,
             logger=self.experiment_logger,
```
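The seven new `Trainer` keyword arguments above are all read from `self.cfg.training_config`, so the training configuration model (`careamics/config/training_model.py`, +35/-6 in this release) evidently gained matching fields. A hedged sketch of how they would surface to a user; the field names come straight from this hunk, but the values and the assignment-after-load pattern are illustrative assumptions:

```python
from careamics.config import load_configuration

cfg = load_configuration("config.yaml")  # hypothetical config file

# each field below is forwarded to pytorch_lightning.Trainer under the same name
cfg.training_config.precision = "16-mixed"            # Trainer(precision=...)
cfg.training_config.max_steps = 10_000                # Trainer(max_steps=...)
cfg.training_config.check_val_every_n_epoch = 2
cfg.training_config.accumulate_grad_batches = 4
cfg.training_config.gradient_clip_val = 1.0
cfg.training_config.gradient_clip_algorithm = "norm"
```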
```diff
@@ -247,6 +247,12 @@ class CAREamist:
                 EarlyStopping(self.cfg.training_config.early_stopping_callback)
             )
 
+    def stop_training(self) -> None:
+        """Stop the training loop."""
+        # raise stop training flag
+        self.trainer.should_stop = True
+        self.trainer.limit_val_batches = 0  # skip validation
+
     # TODO: is there are more elegant way than calling train again after _train_on_paths
     def train(
         self,
@@ -393,9 +399,14 @@ class CAREamist:
         datamodule : TrainDataModule
             Datamodule to train on.
         """
-        #
+        # register datamodule
         self.train_datamodule = datamodule
 
+        # set defaults (in case `stop_training` was called before)
+        self.trainer.should_stop = False
+        self.trainer.limit_val_batches = 1.0  # 100%
+
+        # train
         self.trainer.fit(self.model, datamodule=datamodule)
 
     def _train_on_array(
```
```diff
@@ -511,7 +522,7 @@ class CAREamist:
         tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["tiff", "custom"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
@@ -527,7 +538,7 @@ class CAREamist:
         tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
     ) -> Union[list[NDArray], NDArray]: ...
 
@@ -540,7 +551,7 @@ class CAREamist:
         tile_overlap: Optional[tuple[int, ...]] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
```
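Taken together, these changes make a running `CAREamist` interruptible: `stop_training` raises Lightning's `should_stop` flag and skips validation, and `train` resets both defaults before fitting, so the same engine can be trained again. A minimal usage sketch, assuming the `CAREamist` API exactly as shown above (paths and data are hypothetical):

```python
from pathlib import Path

from careamics import CAREamist

engine = CAREamist(source=Path("config.yaml"), work_dir=Path("runs"))

# from e.g. a GUI callback while `engine.train(...)` runs in another thread:
#     engine.stop_training()  # sets trainer.should_stop and skips validation

engine.train(train_source="train.tiff")  # resets the stop flag before fitting
```

Note also that `tta_transforms` now defaults to `False` in all three `predict` overloads above, so test-time augmentation has to be requested explicitly.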
careamics/cli/conf.py
ADDED
@@ -0,0 +1,391 @@
```python
"""Configuration building convenience functions for the CAREamics CLI."""

import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import click
import typer
import yaml
from typing_extensions import Annotated

from ..config import (
    Configuration,
    create_care_configuration,
    create_n2n_configuration,
    create_n2v_configuration,
    save_configuration,
)

WORK_DIR = Path.cwd()

app = typer.Typer()


def _config_builder_exit(ctx: typer.Context, config: Configuration) -> None:
    """
    Function to be called at the end of a CLI configuration builder.

    Saves the `config` object and performs other functionality depending on the command
    context.

    Parameters
    ----------
    ctx : typer.Context
        Typer Context.
    config : Configuration
        CAREamics configuration.
    """
    conf_path = (ctx.obj.dir / ctx.obj.name).with_suffix(".yaml")
    save_configuration(config, conf_path)
    if ctx.obj.print:
        print(yaml.dump(config.model_dump(), indent=2))


@dataclass
class ConfOptions:
    """Data class for containing CLI `conf` command option values."""

    dir: Path
    name: str
    force: bool
    print: bool


@app.callback()
def conf_options(  # numpydoc ignore=PR01
    ctx: typer.Context,
    dir: Annotated[
        Path,
        typer.Option(
            "--dir", "-d", exists=True, help="Directory to save the config file to."
        ),
    ] = WORK_DIR,
    name: Annotated[
        str, typer.Option("--name", "-n", help="The config file name.")
    ] = "config",
    force: Annotated[
        bool,
        typer.Option(
            "--force", "-f", help="Whether to overwrite existing config files."
        ),
    ] = False,
    print: Annotated[
        bool,
        typer.Option(
            "--print",
            "-p",
            help="Whether to print the config file to the console.",
        ),
    ] = False,
):
    """Build and save CAREamics configuration files."""
    # Callback is called still on --help command
    # If a config exists it will complain that you need to use the -f flag
    if "--help" in sys.argv:
        return
    conf_path = (dir / name).with_suffix(".yaml")
    if conf_path.exists() and not force:
        raise FileExistsError(f"To overwrite '{conf_path}' use flag --force/-f.")

    ctx.obj = ConfOptions(dir, name, force, print)


def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
    """
    Callback for --patch-size option.

    Parameters
    ----------
    value : (int, int, int)
        Patch size value.

    Returns
    -------
    (int, int, int) | (int, int)
        If the last element in `value` is -1 the tuple is reduced to the first two
        values.
    """
    if value[2] == -1:
        return value[:2]
    return value


# TODO: Need to decide how to parse model kwargs
#   - Could be json style string to be loaded as dict e.g. {"depth": 3}
#       - Cons: Annoying to type, easily have syntax errors
#   - Could parse all unknown options as model kwargs
#       - Cons: There could be argument name clashes


@app.command()
def care(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_care_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    _config_builder_exit(ctx, config)


@app.command()
def n2n(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_n2n_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    _config_builder_exit(ctx, config)


@app.command()
def n2v(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = True,
    use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
    n_channels: Annotated[
        int, typer.Option(help="Number of channels (in and out)")
    ] = 1,
    roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
    masked_pixel_percentage: Annotated[
        float, typer.Option(help="Percentage of pixels masked in each patch.")
    ] = 0.2,
    struct_n2v_axis: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["horizontal", "vertical", "none"])),
    ] = "none",
    struct_n2v_span: Annotated[
        int, typer.Option(help="Span of the structN2V mask.")
    ] = 5,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_kwargs` (passed as a
    parameter-value dictionary). Note that `use_n2v2` and 'n_channels' override the
    corresponding parameters passed in `model_kwargs`.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.
    """
    config = create_n2v_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        use_n2v2=use_n2v2,
        n_channels=n_channels,
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_n2v_axis=struct_n2v_axis,
        struct_n2v_span=struct_n2v_span,
        logger=logger,
        # TODO: Model kwargs
    )
    _config_builder_exit(ctx, config)
```
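Each `conf` subcommand is a thin wrapper around the corresponding configuration factory, followed by `_config_builder_exit`, which writes the YAML file. The same result can be obtained directly from Python; a minimal sketch using only names visible in this diff (argument values are illustrative, the list form of `patch_size` is an assumption, and remaining parameters keep their defaults):

```python
from careamics.config import create_n2v_configuration, save_configuration

# roughly what `careamics conf n2v --experiment-name demo --axes SYX
# --patch-size 64 64 -1 --batch-size 8 --num-epochs 100` produces
config = create_n2v_configuration(
    experiment_name="demo",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],  # the patch-size callback reduces (64, 64, -1) to 2D
    batch_size=8,
    num_epochs=100,
)
save_configuration(config, "config.yaml")
```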
careamics/cli/main.py
ADDED
@@ -0,0 +1,134 @@
```python
"""
Module for CLI functionality and entrypoint.

Contains the CLI entrypoint, the `run` function; and first level subcommands `train`
and `predict`. The `conf` subcommand is added through the `app.add_typer` function, and
its implementation is contained in the conf.py file.
"""

from pathlib import Path
from typing import Optional

import typer
from typing_extensions import Annotated

from ..careamist import CAREamist
from . import conf

app = typer.Typer(
    help="Run CAREamics algorithms from the command line, including Noise2Void "
    "and its many variants and cousins"
)
app.add_typer(
    conf.app,
    name="conf",
    # callback=conf.conf_options
)


@app.command()
def train(  # numpydoc ignore=PR01
    source: Annotated[
        Path,
        typer.Argument(
            help="Path to a configuration file or a trained model.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    train_source: Annotated[
        Path,
        typer.Option(
            "--train-source",
            "-ts",
            help="Path to the training data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ],
    train_target: Annotated[
        Optional[Path],
        typer.Option(
            "--train-target",
            "-tt",
            help="Path to train target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_source: Annotated[
        Optional[Path],
        typer.Option(
            "--val-source",
            "-vs",
            help="Path to validation data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_target: Annotated[
        Optional[Path],
        typer.Option(
            "--val-target",
            "-vt",
            help="Path to validation target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    use_in_memory: Annotated[
        bool,
        typer.Option(
            "--use-in-memory/--not-in-memory",
            "-m/-M",
            help="Use in memory dataset if possible.",
        ),
    ] = True,
    val_percentage: Annotated[
        float,
        typer.Option(help="Percentage of files to use for validation."),
    ] = 0.1,
    val_minimum_split: Annotated[
        int,
        typer.Option(help="Minimum number of files to use for validation,"),
    ] = 1,
    work_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--work-dir",
            "-wd",
            help=("Path to working directory in which to save checkpoints and " "logs"),
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ] = None,
):
    """Train CAREamics models."""
    engine = CAREamist(source=source, work_dir=work_dir)
    engine.train(
        train_source=train_source,
        val_source=val_source,
        train_target=train_target,
        val_target=val_target,
        use_in_memory=use_in_memory,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_split,
    )


@app.command()
def predict():  # numpydoc ignore=PR01
    """Create and save predictions from CAREamics models."""
    # TODO: Need a save predict to workdir function
    raise NotImplementedError


def run():
    """CLI Entry point."""
    app()
```
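`run` simply invokes the Typer `app`; the new `entry_points.txt` (+2 lines) presumably registers it as the console-script entry point. That also makes the CLI easy to exercise programmatically, for example with Typer's test runner; a sketch assuming only the commands and options defined in this diff (values are illustrative):

```python
from typer.testing import CliRunner

from careamics.cli.main import app

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "conf", "n2v",
        "--experiment-name", "demo",
        "--axes", "SYX",
        "--patch-size", "64", "64", "-1",  # -1 reduces the patch size to 2D
        "--batch-size", "8",
        "--num-epochs", "100",
    ],
)
assert result.exit_code == 0  # config.yaml written to the current directory
```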