careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,123 @@
+"""Callback Pydantic models."""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Literal, Optional
+
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+)
+
+
+class CheckpointModel(BaseModel):
+    """Checkpoint saving callback Pydantic model.
+
+    The parameters correspond to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
+    verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
+    save_weights_only: bool = Field(default=False, validate_default=True)
+    """When `True`, only the model's weights will be saved (model.save_weights)."""
+
+    save_last: Optional[Literal[True, False, "link"]] = Field(
+        default=True, validate_default=True
+    )
+    """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
+
+    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """If `save_top_k == k`, the best k models according to the quantity monitored
+    will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
+    all models are saved."""
+
+    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
+    save file is made based on either the maximization or the minimization of the
+    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this
+    should be 'min', etc.
+    """
+
+    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+    """When `True`, the checkpoint filenames will contain the metric name."""
+
+    every_n_train_steps: Optional[int] = Field(
+        default=None, ge=1, le=10, validate_default=True
+    )
+    """Number of training steps between checkpoints."""
+
+    train_time_interval: Optional[timedelta] = Field(
+        default=None, validate_default=True
+    )
+    """Checkpoints are monitored at the specified time interval."""
+
+    every_n_epochs: Optional[int] = Field(
+        default=None, ge=1, le=10, validate_default=True
+    )
+    """Number of epochs between checkpoints."""
+
+
+class EarlyStoppingModel(BaseModel):
+    """Early stopping callback Pydantic model.
+
+    The parameters correspond to those of
+    `pytorch_lightning.callbacks.EarlyStopping`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
+    absolute change of less than or equal to `min_delta` will count as no
+    improvement."""
+
+    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """Number of checks with no improvement after which training will be stopped."""
+
+    verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
+    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
+    """One of {min, max, auto}."""
+
+    check_finite: bool = Field(default=True, validate_default=True)
+    """When `True`, stops training when the monitored quantity becomes `NaN` or
+    `inf`."""
+
+    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training immediately once the monitored quantity reaches this threshold."""
+
+    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training as soon as the monitored quantity becomes worse than this
+    threshold."""
+
+    check_on_train_epoch_end: Optional[bool] = Field(
+        default=False, validate_default=True
+    )
+    """Whether to run early stopping at the end of the training epoch. If `False`,
+    the check runs at the end of validation."""
+
+    log_rank_zero_only: bool = Field(default=False, validate_default=True)
+    """When `True`, logs the status of the early stopping callback only for the
+    rank 0 process."""
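
For orientation, a minimal usage sketch of these two models (the import path from `careamics/config/callback_model.py` in the file list above is an assumption, not confirmed by the diff):

from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel

# Defaults mirror pytorch_lightning's ModelCheckpoint / EarlyStopping.
checkpoint = CheckpointModel(save_top_k=5, mode="min")
early_stopping = EarlyStoppingModel(patience=5, min_delta=0.01)

# validate_assignment=True re-validates on attribute assignment:
checkpoint.save_top_k = 3      # fine, within the ge=1 / le=10 bounds
# checkpoint.save_top_k = 50   # would raise a pydantic ValidationError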
@@ -0,0 +1,583 @@
+"""Convenience functions to create configurations for training and inference."""
+
+from typing import Any, Dict, List, Literal, Optional
+
+from .architectures import UNetModel
+from .configuration_model import Configuration
+from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
+from .support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedLoss,
+    SupportedPixelManipulation,
+    SupportedTransform,
+)
+from .training_model import TrainingConfig
+
+
+# TODO rename ?
+def _create_supervised_configuration(
+    algorithm_type: Literal["fcn"],
+    algorithm: Literal["care", "n2n"],
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels_in: int = 1,
+    n_channels_out: int = 1,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training CARE or Noise2Noise.
+
+    Parameters
+    ----------
+    algorithm_type : Literal["fcn"]
+        Type of the algorithm.
+    algorithm : Literal["care", "n2n"]
+        Algorithm to use.
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default 1.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training CARE or Noise2Noise.
+    """
+    # if there are channels, we need to specify their number
+    if "C" in axes and n_channels_in == 1:
+        raise ValueError(
+            f"Number of channels in must be specified when using channels "
+            f"(got {n_channels_in} channel)."
+        )
+    elif "C" not in axes and n_channels_in > 1:
+        raise ValueError(
+            f"C is not present in the axes, but number of channels is specified "
+            f"(got {n_channels_in} channels)."
+        )
+
+    # model
+    if model_kwargs is None:
+        model_kwargs = {}
+    model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+    model_kwargs["in_channels"] = n_channels_in
+    model_kwargs["num_classes"] = n_channels_out
+    model_kwargs["independent_channels"] = independent_channels
+
+    unet_model = UNetModel(
+        architecture=SupportedArchitecture.UNET.value,
+        **model_kwargs,
+    )
+
+    # algorithm model
+    algorithm = FCNAlgorithmConfig(
+        algorithm_type=algorithm_type,
+        algorithm=algorithm,
+        loss=loss,
+        model=unet_model,
+    )
+
+    # augmentations
+    if use_augmentations:
+        transforms: List[Dict[str, Any]] = [
+            {
+                "name": SupportedTransform.XY_FLIP.value,
+            },
+            {
+                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+            },
+        ]
+    else:
+        transforms = []
+
+    # data model
+    data = DataConfig(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        transforms=transforms,
+    )
+
+    # training model
+    training = TrainingConfig(
+        num_epochs=num_epochs,
+        batch_size=batch_size,
+        logger=None if logger == "none" else logger,
+    )
+
+    # create configuration
+    configuration = Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm,
+        data_config=data,
+        training_config=training,
+    )
+
+    return configuration
+
+
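The channel checks at the top of `_create_supervised_configuration` enforce that the "C" axis and the channel count agree. A standalone sketch of that rule (not careamics code, just the same check reproduced for illustration):

def check_channels(axes: str, n_channels: int) -> None:
    # "C" in axes requires an explicit channel count (> 1),
    # and a channel count > 1 requires "C" in axes.
    if "C" in axes and n_channels == 1:
        raise ValueError("Number of channels must be specified when using channels.")
    if "C" not in axes and n_channels > 1:
        raise ValueError("C is not present in the axes, but channels were given.")

check_channels("YX", 1)     # fine: no channel axis, single channel
check_channels("CYX", 3)    # fine: channel axis with three channels
# check_channels("CYX", 1)  # would raise ValueError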
+def create_care_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training CARE.
+
+    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+    otherwise 2.
+
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number
+    of channels. Likewise, if you set the number of channels, then "C" must be
+    present in `axes`.
+
+    To set the number of output channels, use the `n_channels_out` parameter. If it
+    is not specified, it will be assumed to be equal to `n_channels_in`.
+
+    By default, all channels are trained together. To train all channels
+    independently, set `independent_channels` to True.
+
+    By setting `use_augmentations` to False, the only transformation applied will be
+    normalization.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training CARE.
+    """
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
+    return _create_supervised_configuration(
+        algorithm_type="fcn",
+        algorithm="care",
+        experiment_name=experiment_name,
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        use_augmentations=use_augmentations,
+        independent_channels=independent_channels,
+        loss=loss,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
+        logger=logger,
+        model_kwargs=model_kwargs,
+    )
+
+
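The CARE factory carries no Examples section; a minimal hedged call, in the style of the N2V examples further down (values are illustrative, and the top-level re-export from `careamics.config` is assumed from the `config/__init__.py` entry in the file list):

from careamics.config import create_care_configuration

config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
)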
+def create_n2n_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training Noise2Noise.
+
+    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+    otherwise 2.
+
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number
+    of channels. Likewise, if you set the number of channels, then "C" must be
+    present in `axes`.
+
+    To set the number of output channels, use the `n_channels_out` parameter. If it
+    is not specified, it will be assumed to be equal to `n_channels_in`.
+
+    By default, all channels are trained together. To train all channels
+    independently, set `independent_channels` to True.
+
+    By setting `use_augmentations` to False, the only transformation applied will be
+    normalization.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training Noise2Noise.
+    """
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
+    return _create_supervised_configuration(
+        algorithm_type="fcn",
+        algorithm="n2n",
+        experiment_name=experiment_name,
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        use_augmentations=use_augmentations,
+        independent_channels=independent_channels,
+        loss=loss,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
+        logger=logger,
+        model_kwargs=model_kwargs,
+    )
+
+
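`create_n2n_configuration` mirrors the CARE factory; the sketch below highlights the `-1` sentinel on `n_channels_out`, which the function replaces with `n_channels_in` (import path assumed as above):

from careamics.config import create_n2n_configuration

# With "C" in the axes, n_channels_in must be set explicitly; the default
# n_channels_out=-1 falls back to n_channels_in before validation.
config = create_n2n_configuration(
    experiment_name="n2n_experiment",
    data_type="tiff",
    axes="CYX",
    patch_size=[64, 64],
    batch_size=16,
    num_epochs=50,
    n_channels_in=2,
)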
+def create_n2v_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = True,
+    use_n2v2: bool = False,
+    n_channels: int = 1,
+    roi_size: int = 11,
+    masked_pixel_percentage: float = 0.2,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training Noise2Void.
+
+    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
+    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
+    (structN2V) parameters, or set `use_n2v2` to True (N2V2).
+
+    N2V2 modifies the UNet architecture by adding blur pool layers and removing the
+    skip connections, thus removing checkerboard artefacts. StructN2V is used when
+    vertical or horizontal correlations are present in the noise; it applies an
+    additional mask to the neighbors of the manipulated pixels.
+
+    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+    otherwise 2.
+
+    If "C" is present in `axes`, then you need to set `n_channels` to the number of
+    channels.
+
+    By default, all channels are trained independently. To train all channels
+    together, set `independent_channels` to False.
+
+    By setting `use_augmentations` to False, the only transformations applied will be
+    normalization and N2V manipulation.
+
+    The `roi_size` parameter specifies the size of the area around each pixel that
+    will be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how
+    many pixels per patch will be manipulated.
+
+    The parameters of the UNet can be specified in the `model_kwargs` (passed as a
+    parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
+    corresponding parameters passed in `model_kwargs`.
+
+    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then a structN2V
+    mask will be applied to each manipulated pixel.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default True.
+    use_n2v2 : bool, optional
+        Whether to use N2V2, by default False.
+    n_channels : int, optional
+        Number of channels (in and out), by default 1.
+    roi_size : int, optional
+        N2V pixel manipulation area, by default 11.
+    masked_pixel_percentage : float, optional
+        Percentage of pixels masked in each patch, by default 0.2.
+    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+        Axis along which to apply structN2V mask, by default "none".
+    struct_n2v_span : int, optional
+        Span of the structN2V mask, by default 5.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training N2V.
+
+    Examples
+    --------
+    Minimum example:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100
+    ... )
+
+    To use N2V2, simply pass the `use_n2v2` parameter:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v2_experiment",
+    ...     data_type="tiff",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     use_n2v2=True
+    ... )
+
+    For structN2V, there are two parameters to set, `struct_n2v_axis` and
+    `struct_n2v_span`:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="structn2v_experiment",
+    ...     data_type="tiff",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     struct_n2v_axis="horizontal",
+    ...     struct_n2v_span=7
+    ... )
+
+    If you are training multiple channels independently, then you need to specify the
+    number of channels:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YXC",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     n_channels=3
+    ... )
+
+    If instead you want to train multiple channels together, you need to turn off the
+    `independent_channels` parameter:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YXC",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     independent_channels=False,
+    ...     n_channels=3
+    ... )
+
+    To turn off the augmentations, except normalization and N2V manipulation, use the
+    relevant keyword argument:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YX",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     use_augmentations=False
+    ... )
+    """
+    # if there are channels, we need to specify their number
+    if "C" in axes and n_channels == 1:
+        raise ValueError(
+            f"Number of channels must be specified when using channels "
+            f"(got {n_channels} channel)."
+        )
+    elif "C" not in axes and n_channels > 1:
+        raise ValueError(
+            f"C is not present in the axes, but number of channels is specified "
+            f"(got {n_channels} channels)."
+        )
+
+    # model
+    if model_kwargs is None:
+        model_kwargs = {}
+    model_kwargs["n2v2"] = use_n2v2
+    model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+    model_kwargs["in_channels"] = n_channels
+    model_kwargs["num_classes"] = n_channels
+    model_kwargs["independent_channels"] = independent_channels
+
+    unet_model = UNetModel(
+        architecture=SupportedArchitecture.UNET.value,
+        **model_kwargs,
+    )
+
+    # algorithm model
+    algorithm = FCNAlgorithmConfig(
+        algorithm_type="fcn",
+        algorithm=SupportedAlgorithm.N2V.value,
+        loss=SupportedLoss.N2V.value,
+        model=unet_model,
+    )
+
+    # augmentations
+    if use_augmentations:
+        transforms: List[Dict[str, Any]] = [
+            {
+                "name": SupportedTransform.XY_FLIP.value,
+            },
+            {
+                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+            },
+        ]
+    else:
+        transforms = []
+
+    # N2V manipulation (covers the n2v2 and structN2V variants)
+    n2v_transform = {
+        "name": SupportedTransform.N2V_MANIPULATE.value,
+        "strategy": (
+            SupportedPixelManipulation.MEDIAN.value
+            if use_n2v2
+            else SupportedPixelManipulation.UNIFORM.value
+        ),
+        "roi_size": roi_size,
+        "masked_pixel_percentage": masked_pixel_percentage,
+        "struct_mask_axis": struct_n2v_axis,
+        "struct_mask_span": struct_n2v_span,
+    }
+    transforms.append(n2v_transform)
+
+    # data model
+    data = DataConfig(
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        transforms=transforms,
+    )
+
+    # training model
+    training = TrainingConfig(
+        num_epochs=num_epochs,
+        batch_size=batch_size,
+        logger=None if logger == "none" else logger,
+    )
+
+    # create configuration
+    configuration = Configuration(
+        experiment_name=experiment_name,
+        algorithm_config=algorithm,
+        data_config=data,
+        training_config=training,
+    )
+
+    return configuration
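
For a sense of scale, a back-of-the-envelope sketch for `masked_pixel_percentage` (assuming, as its docstring suggests, that the value is expressed in percent, so 0.2 means 0.2 %):

# Hypothetical arithmetic, not careamics code.
patch_pixels = 64 * 64                 # a [64, 64] patch from the examples above
masked = patch_pixels * 0.2 / 100      # default masked_pixel_percentage
print(round(masked))                   # ~8 pixels manipulated per patch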