careamics-0.0.1-py3-none-any.whl → careamics-0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of careamics has been flagged as possibly problematic.
Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/config/callback_model.py
@@ -0,0 +1,123 @@
+ """Callback Pydantic models."""
+
+ from __future__ import annotations
+
+ from datetime import timedelta
+ from typing import Literal, Optional
+
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+ )
+
+
+ class CheckpointModel(BaseModel):
+     """Checkpoint saving callback Pydantic model.
+
+     The parameters correspond to those of
+     `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+     See:
+     https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+     """Quantity to monitor."""
+
+     verbose: bool = Field(default=False, validate_default=True)
+     """Verbosity mode."""
+
+     save_weights_only: bool = Field(default=False, validate_default=True)
+     """When `True`, only the model's weights will be saved (model.save_weights)."""
+
+     save_last: Optional[Literal[True, False, "link"]] = Field(
+         default=True, validate_default=True
+     )
+     """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
+
+     save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+     """If `save_top_k == k`, the best k models according to the quantity monitored
+     will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
+     all models are saved."""
+
+     mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+     """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
+     save file is made based on either the maximization or the minimization of the
+     monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
+     be 'min', etc.
+     """
+
+     auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+     """When `True`, the checkpoint filenames will contain the metric name."""
+
+     every_n_train_steps: Optional[int] = Field(
+         default=None, ge=1, le=10, validate_default=True
+     )
+     """Number of training steps between checkpoints."""
+
+     train_time_interval: Optional[timedelta] = Field(
+         default=None, validate_default=True
+     )
+     """Checkpoints are monitored at the specified time interval."""
+
+     every_n_epochs: Optional[int] = Field(
+         default=None, ge=1, le=10, validate_default=True
+     )
+     """Number of epochs between checkpoints."""
+
+
+ class EarlyStoppingModel(BaseModel):
+     """Early stopping callback Pydantic model.
+
+     The parameters correspond to those of
+     `pytorch_lightning.callbacks.EarlyStopping`.
+
+     See:
+     https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+     """Quantity to monitor."""
+
+     min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+     """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
+     absolute change of less than or equal to `min_delta` will count as no improvement."""
+
+     patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+     """Number of checks with no improvement after which training will be stopped."""
+
+     verbose: bool = Field(default=False, validate_default=True)
+     """Verbosity mode."""
+
+     mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
+     """One of {min, max, auto}."""
+
+     check_finite: bool = Field(default=True, validate_default=True)
+     """When `True`, stops training when the monitored quantity becomes `NaN` or
+     `inf`."""
+
+     stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+     """Stop training immediately once the monitored quantity reaches this threshold."""
+
+     divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+     """Stop training as soon as the monitored quantity becomes worse than this
+     threshold."""
+
+     check_on_train_epoch_end: Optional[bool] = Field(
+         default=False, validate_default=True
+     )
+     """Whether to run early stopping at the end of the training epoch. If this is
+     `False`, then the check runs at the end of validation."""
+
+     log_rank_zero_only: bool = Field(default=False, validate_default=True)
+     """When set to `True`, logs the status of the early stopping callback only for the
+     rank 0 process."""
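
Note (editorial, not part of the diff): these two models only validate and store configuration values; how CAREamics passes them on to PyTorch Lightning is not shown in this hunk. Since the field names mirror the Lightning callback arguments, a minimal sketch of how such validated models could be mapped onto the corresponding Lightning callbacks might look like this (illustrative only):

# Illustrative sketch, not part of the careamics 0.0.2 release.
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel

# Validated, assignment-checked configuration objects.
ckpt_cfg = CheckpointModel(save_top_k=5, every_n_epochs=1)
es_cfg = EarlyStoppingModel(patience=10, min_delta=0.01)

# Field names mirror the Lightning callback arguments, so a dump can be splatted.
checkpoint_cb = ModelCheckpoint(**ckpt_cfg.model_dump())
early_stopping_cb = EarlyStopping(**es_cfg.model_dump())

Because the models set `validate_assignment=True`, a later assignment such as `ckpt_cfg.save_top_k = 50` would also be checked against the field constraints (here `le=10`) and raise a validation error.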
careamics/config/configuration_factory.py
@@ -0,0 +1,575 @@
+ """Convenience functions to create configurations for training and inference."""
+
+ from typing import Any, Dict, List, Literal, Optional
+
+ from .algorithm_model import AlgorithmConfig
+ from .architectures import UNetModel
+ from .configuration_model import Configuration
+ from .data_model import DataConfig
+ from .support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedPixelManipulation,
+     SupportedTransform,
+ )
+ from .training_model import TrainingConfig
+
+
+ def _create_supervised_configuration(
+     algorithm: Literal["care", "n2n"],
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     independent_channels: bool = False,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels_in: int = 1,
+     n_channels_out: int = 1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training CARE or Noise2Noise.
+
+     Parameters
+     ----------
+     algorithm : Literal["care", "n2n"]
+         Algorithm to use.
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     independent_channels : bool, optional
+         Whether to train all channels independently, by default False.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels_in : int, optional
+         Number of channels in, by default 1.
+     n_channels_out : int, optional
+         Number of channels out, by default 1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training CARE or Noise2Noise.
+     """
+     # if there are channels, we need to specify their number
+     if "C" in axes and n_channels_in == 1:
+         raise ValueError(
+             f"Number of channels in must be specified when using channels "
+             f"(got {n_channels_in} channel)."
+         )
+     elif "C" not in axes and n_channels_in > 1:
+         raise ValueError(
+             f"C is not present in the axes, but number of channels is specified "
+             f"(got {n_channels_in} channels)."
+         )
+
+     # model
+     if model_kwargs is None:
+         model_kwargs = {}
+     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+     model_kwargs["in_channels"] = n_channels_in
+     model_kwargs["num_classes"] = n_channels_out
+     model_kwargs["independent_channels"] = independent_channels
+
+     unet_model = UNetModel(
+         architecture=SupportedArchitecture.UNET.value,
+         **model_kwargs,
+     )
+
+     # algorithm model
+     algorithm = AlgorithmConfig(
+         algorithm=algorithm,
+         loss=loss,
+         model=unet_model,
+     )
+
+     # augmentations
+     if use_augmentations:
+         transforms: List[Dict[str, Any]] = [
+             {
+                 "name": SupportedTransform.XY_FLIP.value,
+             },
+             {
+                 "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+             },
+         ]
+     else:
+         transforms = []
+
+     # data model
+     data = DataConfig(
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         transforms=transforms,
+     )
+
+     # training model
+     training = TrainingConfig(
+         num_epochs=num_epochs,
+         batch_size=batch_size,
+         logger=None if logger == "none" else logger,
+     )
+
+     # create configuration
+     configuration = Configuration(
+         experiment_name=experiment_name,
+         algorithm_config=algorithm,
+         data_config=data,
+         training_config=training,
+     )
+
+     return configuration
+
+
+ def create_care_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     independent_channels: bool = False,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels_in: int = 1,
+     n_channels_out: int = -1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training CARE.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
+     channels. Likewise, if you set the number of channels, then "C" must be present in
+     `axes`.
+
+     To set the number of output channels, use the `n_channels_out` parameter. If it is
+     not specified, it will be assumed to be equal to `n_channels_in`.
+
+     By default, all channels are trained together. To train all channels independently,
+     set `independent_channels` to True.
+
+     By setting `use_augmentations` to False, the only transformation applied will be
+     normalization.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     independent_channels : bool, optional
+         Whether to train all channels independently, by default False.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels_in : int, optional
+         Number of channels in, by default 1.
+     n_channels_out : int, optional
+         Number of channels out, by default -1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training CARE.
+     """
+     if n_channels_out == -1:
+         n_channels_out = n_channels_in
+
+     return _create_supervised_configuration(
+         algorithm="care",
+         experiment_name=experiment_name,
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         num_epochs=num_epochs,
+         use_augmentations=use_augmentations,
+         independent_channels=independent_channels,
+         loss=loss,
+         n_channels_in=n_channels_in,
+         n_channels_out=n_channels_out,
+         logger=logger,
+         model_kwargs=model_kwargs,
+     )
+
+
+ def create_n2n_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     independent_channels: bool = False,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels_in: int = 1,
+     n_channels_out: int = -1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training Noise2Noise.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
+     channels. Likewise, if you set the number of channels, then "C" must be present in
+     `axes`.
+
+     To set the number of output channels, use the `n_channels_out` parameter. If it is
+     not specified, it will be assumed to be equal to `n_channels_in`.
+
+     By default, all channels are trained together. To train all channels independently,
+     set `independent_channels` to True.
+
+     By setting `use_augmentations` to False, the only transformation applied will be
+     normalization.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     independent_channels : bool, optional
+         Whether to train all channels independently, by default False.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels_in : int, optional
+         Number of channels in, by default 1.
+     n_channels_out : int, optional
+         Number of channels out, by default -1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training Noise2Noise.
+     """
+     if n_channels_out == -1:
+         n_channels_out = n_channels_in
+
+     return _create_supervised_configuration(
+         algorithm="n2n",
+         experiment_name=experiment_name,
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         num_epochs=num_epochs,
+         use_augmentations=use_augmentations,
+         independent_channels=independent_channels,
+         loss=loss,
+         n_channels_in=n_channels_in,
+         n_channels_out=n_channels_out,
+         logger=logger,
+         model_kwargs=model_kwargs,
+     )
+
+
+ def create_n2v_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     independent_channels: bool = True,
+     use_n2v2: bool = False,
+     n_channels: int = 1,
+     roi_size: int = 11,
+     masked_pixel_percentage: float = 0.2,
+     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+     struct_n2v_span: int = 5,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training Noise2Void.
+
+     N2V uses a UNet model to denoise images in a self-supervised manner. To use its
+     variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
+     (structN2V) parameters, or set `use_n2v2` to True (N2V2).
+
+     N2V2 modifies the UNet architecture by adding blur pool layers and removing the
+     skip connections, thus removing checkerboard artefacts. StructN2V is used when
+     vertical or horizontal correlations are present in the noise; it applies an
+     additional mask to the manipulated pixel neighbors.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels` to the number of
+     channels.
+
+     By default, all channels are trained independently. To train all channels together,
+     set `independent_channels` to False.
+
+     By setting `use_augmentations` to False, the only transformations applied will be
+     normalization and N2V manipulation.
+
+     The `roi_size` parameter specifies the size of the area around each pixel that will
+     be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
+     pixels per patch will be manipulated.
+
+     The parameters of the UNet can be specified in the `model_kwargs` (passed as a
+     parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
+     corresponding parameters passed in `model_kwargs`.
+
+     If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then the structN2V
+     mask will be applied to each manipulated pixel.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     independent_channels : bool, optional
+         Whether to train all channels independently, by default True.
+     use_n2v2 : bool, optional
+         Whether to use N2V2, by default False.
+     n_channels : int, optional
+         Number of channels (in and out), by default 1.
+     roi_size : int, optional
+         N2V pixel manipulation area, by default 11.
+     masked_pixel_percentage : float, optional
+         Percentage of pixels masked in each patch, by default 0.2.
+     struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+         Axis along which to apply the structN2V mask, by default "none".
+     struct_n2v_span : int, optional
+         Span of the structN2V mask, by default 5.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training N2V.
+
+     Examples
+     --------
+     Minimum example:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100
+     ... )
+
+     To use N2V2, simply pass the `use_n2v2` parameter:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v2_experiment",
+     ...     data_type="tiff",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     use_n2v2=True
+     ... )
+
+     For structN2V, there are two parameters to set, `struct_n2v_axis` and
+     `struct_n2v_span`:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="structn2v_experiment",
+     ...     data_type="tiff",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     struct_n2v_axis="horizontal",
+     ...     struct_n2v_span=7
+     ... )
+
+     If you are training multiple channels independently, then you need to specify the
+     number of channels:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YXC",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     n_channels=3
+     ... )
+
+     If instead you want to train multiple channels together, you need to turn off the
+     `independent_channels` parameter:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YXC",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     independent_channels=False,
+     ...     n_channels=3
+     ... )
+
+     To turn off the augmentations, except normalization and N2V manipulation, use the
+     relevant keyword argument:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     use_augmentations=False
+     ... )
+     """
+     # if there are channels, we need to specify their number
+     if "C" in axes and n_channels == 1:
+         raise ValueError(
+             f"Number of channels must be specified when using channels "
+             f"(got {n_channels} channel)."
+         )
+     elif "C" not in axes and n_channels > 1:
+         raise ValueError(
+             f"C is not present in the axes, but number of channels is specified "
+             f"(got {n_channels} channel)."
+         )
+
+     # model
+     if model_kwargs is None:
+         model_kwargs = {}
+     model_kwargs["n2v2"] = use_n2v2
+     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+     model_kwargs["in_channels"] = n_channels
+     model_kwargs["num_classes"] = n_channels
+     model_kwargs["independent_channels"] = independent_channels
+
+     unet_model = UNetModel(
+         architecture=SupportedArchitecture.UNET.value,
+         **model_kwargs,
+     )
+
+     # algorithm model
+     algorithm = AlgorithmConfig(
+         algorithm=SupportedAlgorithm.N2V.value,
+         loss=SupportedLoss.N2V.value,
+         model=unet_model,
+     )
+
+     # augmentations
+     if use_augmentations:
+         transforms: List[Dict[str, Any]] = [
+             {
+                 "name": SupportedTransform.XY_FLIP.value,
+             },
+             {
+                 "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+             },
+         ]
+     else:
+         transforms = []
+
+     # n2v2 and structn2v
+     nv2_transform = {
+         "name": SupportedTransform.N2V_MANIPULATE.value,
+         "strategy": (
+             SupportedPixelManipulation.MEDIAN.value
+             if use_n2v2
+             else SupportedPixelManipulation.UNIFORM.value
+         ),
+         "roi_size": roi_size,
+         "masked_pixel_percentage": masked_pixel_percentage,
+         "struct_mask_axis": struct_n2v_axis,
+         "struct_mask_span": struct_n2v_span,
+     }
+     transforms.append(nv2_transform)
+
+     # data model
+     data = DataConfig(
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         transforms=transforms,
+     )
+
+     # training model
+     training = TrainingConfig(
+         num_epochs=num_epochs,
+         batch_size=batch_size,
+         logger=None if logger == "none" else logger,
+     )
+
+     # create configuration
+     configuration = Configuration(
+         experiment_name=experiment_name,
+         algorithm_config=algorithm,
+         data_config=data,
+         training_config=training,
+     )
+
+     return configuration
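
Note (editorial, not part of the diff): unlike `create_n2v_configuration`, the CARE and Noise2Noise factories above carry no docstring examples. A minimal usage sketch, importing straight from the module path that this wheel ships (whether the function is also re-exported from `careamics.config` is not shown in this hunk), could look like this; the argument values are illustrative only:

# Illustrative sketch, not part of the careamics 0.0.2 release.
from careamics.config.configuration_factory import create_care_configuration

config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="tiff",
    axes="SYX",           # sample, Y and X axes; no "C", so the channel defaults apply
    patch_size=[64, 64],  # 2D patches because "Z" is absent from the axes
    batch_size=32,
    num_epochs=100,
)

The returned `Configuration` bundles the algorithm, data, and training sub-configurations assembled in `_create_supervised_configuration` above.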