careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package registry page for more details.

Files changed (87) hide show
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
13
13
  )
14
14
  from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
15
15
 
16
- from careamics.config import (
17
- Configuration,
18
- load_configuration,
19
- )
16
+ from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
20
17
  from careamics.config.support import (
21
18
  SupportedAlgorithm,
22
19
  SupportedArchitecture,
@@ -25,7 +22,7 @@ from careamics.config.support import (
25
22
  )
26
23
  from careamics.dataset.dataset_utils import reshape_array
27
24
  from careamics.lightning import (
28
- CAREamicsModule,
25
+ FCNModule,
29
26
  HyperParametersCallback,
30
27
  PredictDataModule,
31
28
  ProgressBarCallback,
@@ -51,8 +48,6 @@ class CAREamist:
51
48
  work_dir : str, optional
52
49
  Path to working directory in which to save checkpoints and logs,
53
50
  by default None.
54
- experiment_name : str, by default "CAREamics"
55
- Experiment name used for checkpoints.
56
51
  callbacks : list of Callback, optional
57
52
  List of callbacks to use during training and prediction, by default None.
58
53
 
@@ -78,8 +73,7 @@ class CAREamist:
78
73
  def __init__( # numpydoc ignore=GL08
79
74
  self,
80
75
  source: Union[Path, str],
81
- work_dir: Optional[str] = None,
82
- experiment_name: str = "CAREamics",
76
+ work_dir: Optional[Union[Path, str]] = None,
83
77
  callbacks: Optional[list[Callback]] = None,
84
78
  ) -> None: ...
85
79
 
@@ -87,8 +81,7 @@ class CAREamist:
87
81
  def __init__( # numpydoc ignore=GL08
88
82
  self,
89
83
  source: Configuration,
90
- work_dir: Optional[str] = None,
91
- experiment_name: str = "CAREamics",
84
+ work_dir: Optional[Union[Path, str]] = None,
92
85
  callbacks: Optional[list[Callback]] = None,
93
86
  ) -> None: ...
94
87
 
@@ -96,7 +89,6 @@ class CAREamist:
96
89
  self,
97
90
  source: Union[Path, str, Configuration],
98
91
  work_dir: Optional[Union[Path, str]] = None,
99
- experiment_name: str = "CAREamics",
100
92
  callbacks: Optional[list[Callback]] = None,
101
93
  ) -> None:
102
94
  """
@@ -109,18 +101,13 @@ class CAREamist:
109
101
 
110
102
  If no working directory is provided, the current working directory is used.
111
103
 
112
- If `source` is a checkpoint, then `experiment_name` is used to name the
113
- checkpoint, and is recorded in the configuration.
114
-
115
104
  Parameters
116
105
  ----------
117
106
  source : pathlib.Path or str or CAREamics Configuration
118
107
  Path to a configuration file or a trained model.
119
- work_dir : str, optional
108
+ work_dir : str or pathlib.Path, optional
120
109
  Path to working directory in which to save checkpoints and logs,
121
110
  by default None.
122
- experiment_name : str, optional
123
- Experiment name used for checkpoints, by default "CAREamics".
124
111
  callbacks : list of Callback, optional
125
112
  List of callbacks to use during training and prediction, by default None.
126
113
 
@@ -148,9 +135,12 @@ class CAREamist:
148
135
  self.cfg = source
149
136
 
150
137
  # instantiate model
151
- self.model = CAREamicsModule(
152
- algorithm_config=self.cfg.algorithm_config,
153
- )
138
+ if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
139
+ self.model = FCNModule(
140
+ algorithm_config=self.cfg.algorithm_config,
141
+ )
142
+ else:
143
+ raise NotImplementedError("Architecture not supported.")
154
144
 
155
145
  # path to configuration file or model
156
146
  else:
@@ -164,9 +154,12 @@ class CAREamist:
164
154
  self.cfg = load_configuration(source)
165
155
 
166
156
  # instantiate model
167
- self.model = CAREamicsModule(
168
- algorithm_config=self.cfg.algorithm_config,
169
- )
157
+ if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
158
+ self.model = FCNModule(
159
+ algorithm_config=self.cfg.algorithm_config,
160
+ ) # type: ignore
161
+ else:
162
+ raise NotImplementedError("Architecture not supported.")
170
163
 
171
164
  # attempt loading a pre-trained model
172
165
  else:
@@ -192,6 +185,13 @@ class CAREamist:
192
185
  # instantiate trainer
193
186
  self.trainer = Trainer(
194
187
  max_epochs=self.cfg.training_config.num_epochs,
188
+ precision=self.cfg.training_config.precision,
189
+ max_steps=self.cfg.training_config.max_steps,
190
+ check_val_every_n_epoch=self.cfg.training_config.check_val_every_n_epoch,
191
+ enable_progress_bar=self.cfg.training_config.enable_progress_bar,
192
+ accumulate_grad_batches=self.cfg.training_config.accumulate_grad_batches,
193
+ gradient_clip_val=self.cfg.training_config.gradient_clip_val,
194
+ gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
195
195
  callbacks=self.callbacks,
196
196
  default_root_dir=self.work_dir,
197
197
  logger=self.experiment_logger,
@@ -247,6 +247,12 @@ class CAREamist:
247
247
  EarlyStopping(self.cfg.training_config.early_stopping_callback)
248
248
  )
249
249
 
250
def stop_training(self) -> None:
    """
    Ask the running training loop to stop.

    Raises the trainer's ``should_stop`` flag and sets ``limit_val_batches``
    to 0 so that no further validation batches are run.
    """
    trainer = self.trainer
    # raise the stop flag first, then disable the remaining validation
    trainer.should_stop = True
    trainer.limit_val_batches = 0
255
+
250
256
  # TODO: is there are more elegant way than calling train again after _train_on_paths
251
257
  def train(
252
258
  self,
@@ -393,9 +399,14 @@ class CAREamist:
393
399
  datamodule : TrainDataModule
394
400
  Datamodule to train on.
395
401
  """
396
- # record datamodule
402
+ # register datamodule
397
403
  self.train_datamodule = datamodule
398
404
 
405
+ # set defaults (in case `stop_training` was called before)
406
+ self.trainer.should_stop = False
407
+ self.trainer.limit_val_batches = 1.0 # 100%
408
+
409
+ # train
399
410
  self.trainer.fit(self.model, datamodule=datamodule)
400
411
 
401
412
  def _train_on_array(
@@ -511,7 +522,7 @@ class CAREamist:
511
522
  tile_overlap: tuple[int, ...] = (48, 48),
512
523
  axes: Optional[str] = None,
513
524
  data_type: Optional[Literal["tiff", "custom"]] = None,
514
- tta_transforms: bool = True,
525
+ tta_transforms: bool = False,
515
526
  dataloader_params: Optional[dict] = None,
516
527
  read_source_func: Optional[Callable] = None,
517
528
  extension_filter: str = "",
@@ -527,7 +538,7 @@ class CAREamist:
527
538
  tile_overlap: tuple[int, ...] = (48, 48),
528
539
  axes: Optional[str] = None,
529
540
  data_type: Optional[Literal["array"]] = None,
530
- tta_transforms: bool = True,
541
+ tta_transforms: bool = False,
531
542
  dataloader_params: Optional[dict] = None,
532
543
  ) -> Union[list[NDArray], NDArray]: ...
533
544
 
@@ -540,7 +551,7 @@ class CAREamist:
540
551
  tile_overlap: Optional[tuple[int, ...]] = (48, 48),
541
552
  axes: Optional[str] = None,
542
553
  data_type: Optional[Literal["array", "tiff", "custom"]] = None,
543
- tta_transforms: bool = True,
554
+ tta_transforms: bool = False,
544
555
  dataloader_params: Optional[dict] = None,
545
556
  read_source_func: Optional[Callable] = None,
546
557
  extension_filter: str = "",
@@ -0,0 +1,5 @@
1
+ """
2
+ Package containing functions called by the careamics cli.
3
+
4
+ Built using third party package Typer.
5
+ """
careamics/cli/conf.py ADDED
@@ -0,0 +1,391 @@
1
+ """Configuration building convenience functions for the CAREamics CLI."""
2
+
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import click
9
+ import typer
10
+ import yaml
11
+ from typing_extensions import Annotated
12
+
13
+ from ..config import (
14
+ Configuration,
15
+ create_care_configuration,
16
+ create_n2n_configuration,
17
+ create_n2v_configuration,
18
+ save_configuration,
19
+ )
20
+
21
# Typer sub-application holding the `conf` configuration-builder commands.
app = typer.Typer()

# Default value of the `--dir` option; evaluated once, when this module is imported.
WORK_DIR = Path.cwd()
24
+
25
+
26
def _config_builder_exit(ctx: typer.Context, config: Configuration) -> None:
    """
    Finish a CLI configuration builder command.

    The configuration is saved to ``ctx.obj.dir / ctx.obj.name`` with its suffix
    replaced by ``.yaml``. When the ``print`` option is set, the configuration
    is also dumped to the console as YAML.

    Parameters
    ----------
    ctx : typer.Context
        Typer context whose ``obj`` holds the `conf` command options.
    config : Configuration
        CAREamics configuration to save.
    """
    options = ctx.obj
    target = options.dir / options.name
    save_configuration(config, target.with_suffix(".yaml"))

    if not options.print:
        return
    print(yaml.dump(config.model_dump(), indent=2))
44
+
45
+
46
@dataclass
class ConfOptions:
    """
    Option values shared by all commands of the CLI `conf` group.

    Attributes
    ----------
    dir : pathlib.Path
        Directory in which the configuration file is written.
    name : str
        Name of the configuration file.
    force : bool
        Whether an existing configuration file may be overwritten.
    print : bool
        Whether the configuration is echoed to the console.
    """

    # field order matters: instances are built positionally
    dir: Path
    name: str
    force: bool
    print: bool
54
+
55
+
56
@app.callback()
def conf_options(  # numpydoc ignore=PR01
    ctx: typer.Context,
    dir: Annotated[
        Path,
        typer.Option(
            "--dir",
            "-d",
            exists=True,
            # the config is saved *inside* this path, so it must be a directory;
            # a file path would only fail later, when saving
            file_okay=False,
            dir_okay=True,
            help="Directory to save the config file to.",
        ),
    ] = WORK_DIR,
    name: Annotated[
        str, typer.Option("--name", "-n", help="The config file name.")
    ] = "config",
    force: Annotated[
        bool,
        typer.Option(
            "--force", "-f", help="Whether to overwrite existing config files."
        ),
    ] = False,
    print: Annotated[
        bool,
        typer.Option(
            "--print",
            "-p",
            help="Whether to print the config file to the console.",
        ),
    ] = False,
):
    """Build and save CAREamics configuration files."""
    # The callback also runs for `--help`; skip the overwrite check in that case,
    # otherwise an existing config file would prevent the help from being shown.
    if "--help" in sys.argv:
        return

    conf_path = (dir / name).with_suffix(".yaml")
    if conf_path.exists() and not force:
        raise FileExistsError(f"To overwrite '{conf_path}' use flag --force/-f.")

    # stored on the context so that the builder commands can save/print the config
    ctx.obj = ConfOptions(dir, name, force, print)
93
+
94
+
95
def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
    """
    Post-process the value of the ``--patch-size`` option.

    The option always takes three integers; ``-1`` as the third value is a
    placeholder indicating 2D patches.

    Parameters
    ----------
    value : (int, int, int)
        Patch size as parsed from the command line.

    Returns
    -------
    (int, int, int) | (int, int)
        The first two entries of `value` when its third entry is -1, otherwise
        `value` itself.
    """
    is_2d = value[2] == -1
    return value[:2] if is_2d else value
113
+
114
+
115
+ # TODO: Need to decide how to parse model kwargs
116
+ # - Could be json style string to be loaded as dict e.g. {"depth": 3}
117
+ # - Cons: Annoying to type, easily have syntax errors
118
+ # - Could parse all unknown options as model kwargs
119
+ # - Cons: There could be argument name clashes
120
+
121
+
122
@app.command()
def care(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `patch_size` must contain three values,
    otherwise pass -1 as its last value (e.g. --patch-size 64 64 -1).

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_care_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        # `[]` disables augmentations; `None` presumably keeps the factory default
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    _config_builder_exit(ctx, config)
203
+
204
+
205
@app.command()
def n2n(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `patch_size` must contain three values,
    otherwise pass -1 as its last value (e.g. --patch-size 64 64 -1).

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_n2n_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        # `[]` disables augmentations; `None` presumably keeps the factory default
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    _config_builder_exit(ctx, config)
283
+
284
+
285
@app.command()
def n2v(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=patch_size_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")],
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = True,
    use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
    n_channels: Annotated[
        int, typer.Option(help="Number of channels (in and out)")
    ] = 1,
    roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
    masked_pixel_percentage: Annotated[
        float, typer.Option(help="Percentage of pixels masked in each patch.")
    ] = 0.2,
    struct_n2v_axis: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["horizontal", "vertical", "none"]),
            help="Axis of the structN2V mask ('none' disables structN2V).",
        ),
    ] = "none",
    struct_n2v_span: Annotated[
        int, typer.Option(help="Span of the structN2V mask.")
    ] = 5,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes`, then `patch_size` must contain three values,
    otherwise pass -1 as its last value (e.g. --patch-size 64 64 -1).

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The UNet parameters (`model_kwargs`) cannot be set from the command line yet.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.
    """
    config = create_n2v_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        # TODO: fix choosing augmentations
        # `[]` disables augmentations; `None` presumably keeps the factory default
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        use_n2v2=use_n2v2,
        n_channels=n_channels,
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_n2v_axis=struct_n2v_axis,
        struct_n2v_span=struct_n2v_span,
        logger=logger,
        # TODO: Model kwargs
    )
    _config_builder_exit(ctx, config)
careamics/cli/main.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ Module for CLI functionality and entrypoint.
3
+
4
+ Contains the CLI entrypoint, the `run` function; and first level subcommands `train`
5
+ and `predict`. The `conf` subcommand is added through the `app.add_typer` function, and
6
+ its implementation is contained in the conf.py file.
7
+ """
8
+
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import typer
13
+ from typing_extensions import Annotated
14
+
15
+ from ..careamist import CAREamist
16
+ from . import conf
17
+
18
app = typer.Typer(
    help=(
        "Run CAREamics algorithms from the command line, including Noise2Void "
        "and its many variants and cousins"
    )
)
# Register the configuration builders (defined in conf.py) as the `conf`
# sub-command group; their shared options are handled by conf.app's callback.
app.add_typer(conf.app, name="conf")
27
+
28
+
29
@app.command()
def train(  # numpydoc ignore=PR01
    source: Annotated[
        Path,
        typer.Argument(
            help="Path to a configuration file or a trained model.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    train_source: Annotated[
        Path,
        typer.Option(
            "--train-source",
            "-ts",
            help="Path to the training data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ],
    train_target: Annotated[
        Optional[Path],
        typer.Option(
            "--train-target",
            "-tt",
            help="Path to train target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_source: Annotated[
        Optional[Path],
        typer.Option(
            "--val-source",
            "-vs",
            help="Path to validation data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_target: Annotated[
        Optional[Path],
        typer.Option(
            "--val-target",
            "-vt",
            help="Path to validation target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    use_in_memory: Annotated[
        bool,
        typer.Option(
            "--use-in-memory/--not-in-memory",
            "-m/-M",
            help="Use in memory dataset if possible.",
        ),
    ] = True,
    val_percentage: Annotated[
        float,
        typer.Option(
            help="Percentage of files to use for validation, as a fraction (e.g. 0.1)."
        ),
    ] = 0.1,
    val_minimum_split: Annotated[
        int,
        typer.Option(help="Minimum number of files to use for validation."),
    ] = 1,
    work_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--work-dir",
            "-wd",
            help="Path to working directory in which to save checkpoints and logs.",
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ] = None,
):
    """Train CAREamics models."""
    # `source` is either a configuration file or a trained model; both are
    # accepted by CAREamist.
    engine = CAREamist(source=source, work_dir=work_dir)
    engine.train(
        train_source=train_source,
        val_source=val_source,
        train_target=train_target,
        val_target=val_target,
        use_in_memory=use_in_memory,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_split,
    )
123
+
124
+
125
@app.command()
def predict():  # numpydoc ignore=PR01
    """Create and save predictions from CAREamics models."""
    # Prediction from the CLI is not available yet.
    # TODO: needs a function that saves predictions to the working directory
    raise NotImplementedError
130
+
131
+
132
def run():
    """
    Entry point of the CAREamics command-line interface.

    Hands control to the module-level Typer application `app`.
    """
    app()