careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,92 @@
1
+ """Checkpoint saving configuration."""
2
+ from __future__ import annotations
3
+
4
+ from datetime import timedelta
5
+ from typing import Literal, Optional
6
+
7
+ from pydantic import (
8
+ BaseModel,
9
+ ConfigDict,
10
+ Field,
11
+ )
12
+
13
+
14
class CheckpointModel(BaseModel):
    """Checkpoint saving configuration.

    Declarative (pydantic) schema for checkpointing options. Values are validated
    on construction and on assignment (``validate_assignment=True``), and every
    field also validates its default (``validate_default=True``).

    NOTE(review): the field names appear to mirror the keyword arguments of
    Lightning's ``ModelCheckpoint`` callback; how this model is consumed is not
    visible in this file — confirm against the caller.

    Attributes
    ----------
    monitor : Literal["val_loss"]
        Monitored quantity; "val_loss" is the only accepted value (default).
    verbose : bool
        By default False.
    save_weights_only : bool
        By default False.
    mode : Literal["min", "max"]
        By default "min".
    auto_insert_metric_name : bool
        By default False.
    every_n_train_steps : Optional[int]
        Between 1 and 10 when set, by default None.
    train_time_interval : Optional[timedelta]
        By default None.
    every_n_epochs : Optional[int]
        Between 1 and 10 when set, by default None.
    save_last : Optional[Literal[True, False, "link"]]
        By default True.
    save_top_k : int
        Between 1 and 10, by default 3.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
    verbose: bool = Field(default=False, validate_default=True)
    save_weights_only: bool = Field(default=False, validate_default=True)
    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
    every_n_train_steps: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
    train_time_interval: Optional[timedelta] = Field(
        default=None, validate_default=True
    )
    every_n_epochs: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
    save_last: Optional[Literal[True, False, "link"]] = Field(
        default=True, validate_default=True
    )
    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
45
+
46
+
47
class EarlyStoppingModel(BaseModel):
    """Early stopping configuration.

    Declarative (pydantic) schema for early-stopping options. Values are validated
    on construction and on assignment (``validate_assignment=True``), and every
    field also validates its default (``validate_default=True``).

    NOTE(review): `monitor`, `patience`, `mode`, `min_delta`, `check_finite` and
    `verbose` look like keyword arguments of Lightning's ``EarlyStopping``
    callback, whereas fields such as `stop_on_nan`, `restore_best_weights` and the
    `auto_lr_find_*` group do not. How this model is consumed is not visible in
    this file — confirm before unpacking it into a callback.

    Attributes
    ----------
    monitor : Literal["val_loss"]
        Monitored quantity; "val_loss" is the only accepted value (default).
    patience : int
        Between 1 and 10, by default 3.
    mode : Literal["min", "max", "auto"]
        By default "min".
    min_delta : float
        Between 0.0 and 1.0, by default 0.0.
    check_finite : bool
        By default True.
    stop_on_nan : bool
        By default True.
    verbose : bool
        By default False.
    restore_best_weights : bool
        By default True.
    auto_lr_find : bool
        By default False.
    auto_lr_find_patience : int
        Between 1 and 10, by default 3.
    auto_lr_find_mode : Literal["min", "max", "auto"]
        By default "min".
    auto_lr_find_direction : Literal["forward", "backward"]
        By default "backward".
    auto_lr_find_max_lr : float
        Between 0.0 and 1e6, by default 10.0.
    auto_lr_find_min_lr : float
        Between 0.0 and 1e6, by default 1e-8.
    auto_lr_find_num_training : int
        Between 1 and 1_000_000, by default 100.
    auto_lr_find_divergence_threshold : float
        Between 0.0 and 1e6, by default 5.0.
    auto_lr_find_accumulate_grad_batches : int
        Between 1 and 1_000_000, by default 1.
    auto_lr_find_stop_divergence : bool
        By default True.
    auto_lr_find_step_scale : float
        Between 0.0 and 10, by default 0.1.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
    check_finite: bool = Field(default=True, validate_default=True)
    stop_on_nan: bool = Field(default=True, validate_default=True)
    verbose: bool = Field(default=False, validate_default=True)
    restore_best_weights: bool = Field(default=True, validate_default=True)
    auto_lr_find: bool = Field(default=False, validate_default=True)
    auto_lr_find_patience: int = Field(default=3, ge=1, le=10, validate_default=True)
    auto_lr_find_mode: Literal["min", "max", "auto"] = Field(
        default="min", validate_default=True
    )
    auto_lr_find_direction: Literal["forward", "backward"] = Field(
        default="backward", validate_default=True
    )
    auto_lr_find_max_lr: float = Field(
        default=10.0, ge=0.0, le=1e6, validate_default=True
    )
    auto_lr_find_min_lr: float = Field(
        default=1e-8, ge=0.0, le=1e6, validate_default=True
    )
    # Integer fields use an integer upper bound (previously the float 1e6).
    auto_lr_find_num_training: int = Field(
        default=100, ge=1, le=1_000_000, validate_default=True
    )
    auto_lr_find_divergence_threshold: float = Field(
        default=5.0, ge=0.0, le=1e6, validate_default=True
    )
    auto_lr_find_accumulate_grad_batches: int = Field(
        default=1, ge=1, le=1_000_000, validate_default=True
    )
    auto_lr_find_stop_divergence: bool = Field(default=True, validate_default=True)
    # validate_default added for consistency with every other field.
    auto_lr_find_step_scale: float = Field(
        default=0.1, ge=0.0, le=10, validate_default=True
    )
@@ -0,0 +1,460 @@
1
+ """Convenience functions to create configurations for training and inference."""
2
+
3
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4
+
5
+ from albumentations import Compose
6
+
7
+ from .algorithm_model import AlgorithmModel
8
+ from .architectures import UNetModel
9
+ from .configuration_model import Configuration
10
+ from .data_model import DataModel
11
+ from .inference_model import InferenceModel
12
+ from .support import (
13
+ SupportedAlgorithm,
14
+ SupportedArchitecture,
15
+ SupportedLoss,
16
+ SupportedPixelManipulation,
17
+ SupportedTransform,
18
+ )
19
+ from .training_model import TrainingModel
20
+
21
+
22
def create_n2n_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    use_n2v2: bool = False,
    n_channels: int = 1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training N2N (Noise2Noise).

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    By setting `use_augmentations` to False, the only transformation requested in
    the data configuration will be normalization.

    The parameters of the UNet can be specified in `model_kwargs` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `n_channels` and the
    dimensionality inferred from `axes` override the corresponding entries
    (`n2v2`, `in_channels`, `num_classes`, `conv_dims`) of `model_kwargs`. The
    dictionary passed by the caller is not modified.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations (flip and 90-degree rotation in addition to
        normalization), by default True.
    use_n2v2 : bool, optional
        Value of the `n2v2` UNet parameter, by default False.
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None (no extra parameters).

    Returns
    -------
    Configuration
        Configuration for training N2N.
    """
    # model: copy so that the overrides below do not leak into the caller's dict
    model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
    model_kwargs["n2v2"] = use_n2v2
    model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
    model_kwargs["in_channels"] = n_channels
    model_kwargs["num_classes"] = n_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **model_kwargs,
    )

    # algorithm model
    # NOTE(review): algorithm and loss are set to the N2V values, which looks
    # inconsistent with an N2N configuration. Confirm which members
    # SupportedAlgorithm / SupportedLoss provide and what AlgorithmModel accepts
    # before changing this.
    algorithm = AlgorithmModel(
        algorithm=SupportedAlgorithm.N2V.value,
        loss=SupportedLoss.N2V.value,
        model=unet_model,
    )

    # augmentations: normalization always comes first
    transforms: List[Dict[str, Any]] = [
        {
            "name": SupportedTransform.NORMALIZE.value,
        },
    ]
    if use_augmentations:
        transforms.extend(
            [
                {
                    "name": SupportedTransform.NDFLIP.value,
                },
                {
                    "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
                },
            ]
        )

    # data model
    data = DataModel(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is mapped to no logger)
    training = TrainingModel(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    configuration = Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm,
        data_config=data,
        training_config=training,
    )

    return configuration
153
+
154
+
155
def create_n2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    use_n2v2: bool = False,
    n_channels: int = -1,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training N2V.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removing skip
    connections, thus removing checkerboard artefacts. When `use_n2v2` is True, the
    N2V pixel manipulation strategy is also switched from `UNIFORM` to `MEDIAN`.
    StructN2V is used when vertical or horizontal correlations are present in the
    noise; it applies an additional mask to the manipulated pixel neighbors.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels. If "C" is absent, `n_channels` must be left at its default (-1) and a
    single channel is used.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_kwargs` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `n_channels` and the
    dimensionality inferred from `axes` override the corresponding entries
    (`n2v2`, `in_channels`, `num_classes`, `conv_dims`) of `model_kwargs`. The
    dictionary passed by the caller is not modified.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
    n_channels : int, optional
        Number of channels (in and out), by default -1 (must be set if "C" is in
        `axes`).
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None (no extra parameters).

    Returns
    -------
    Configuration
        Configuration for training N2V.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels` is not set, or if `n_channels` is set
        while "C" is not in `axes`.

    Examples
    --------
    Minimum example:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    To use N2V2, simply pass the `use_n2v2` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v2_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_n2v2=True
    ... )

    For structN2V, there are two parameters to set, `struct_n2v_axis` and
    `struct_n2v_span`:
    >>> config = create_n2v_configuration(
    ...     experiment_name="structn2v_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     struct_n2v_axis="horizontal",
    ...     struct_n2v_span=7
    ... )

    If you are training multiple channels together, then you need to specify the number
    of channels:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels=3
    ... )

    To turn off the augmentations, except normalization and N2V manipulation, use the
    relevant keyword argument:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_augmentations=False
    ... )
    """
    # "C" in axes requires an explicit n_channels, and vice versa; -1 is the
    # "not specified" sentinel and means a single channel when "C" is absent.
    if "C" in axes and n_channels == -1:
        raise ValueError(
            "Number of channels must be specified when using channels "
            f"(got {n_channels} channel)."
        )
    elif "C" not in axes and n_channels != -1:
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )
    elif n_channels == -1:
        n_channels = 1

    # model: copy so that the overrides below do not leak into the caller's dict
    model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
    model_kwargs["n2v2"] = use_n2v2
    model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
    model_kwargs["in_channels"] = n_channels
    model_kwargs["num_classes"] = n_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **model_kwargs,
    )

    # algorithm model
    algorithm = AlgorithmModel(
        algorithm=SupportedAlgorithm.N2V.value,
        loss=SupportedLoss.N2V.value,
        model=unet_model,
    )

    # augmentations: normalization always comes first
    transforms: List[Dict[str, Any]] = [
        {
            "name": SupportedTransform.NORMALIZE.value,
        },
    ]
    if use_augmentations:
        transforms.extend(
            [
                {
                    "name": SupportedTransform.NDFLIP.value,
                },
                {
                    "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
                },
            ]
        )

    # N2V manipulation (last transform); N2V2 uses the median strategy and
    # structN2V is configured through the struct_mask_* entries
    n2v_transform = {
        "name": SupportedTransform.N2V_MANIPULATE.value,
        "strategy": SupportedPixelManipulation.MEDIAN.value
        if use_n2v2
        else SupportedPixelManipulation.UNIFORM.value,
        "roi_size": roi_size,
        "masked_pixel_percentage": masked_pixel_percentage,
        "struct_mask_axis": struct_n2v_axis,
        "struct_mask_span": struct_n2v_span,
    }
    transforms.append(n2v_transform)

    # data model
    data = DataModel(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is mapped to no logger)
    training = TrainingModel(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    configuration = Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm,
        data_config=data,
        training_config=training,
    )

    return configuration
394
+
395
+
396
+ # TODO add tests
397
def create_inference_configuration(
    training_configuration: Configuration,
    tile_size: Optional[Tuple[int, ...]] = None,
    tile_overlap: Optional[Tuple[int, ...]] = None,
    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
    axes: Optional[str] = None,
    transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
    tta_transforms: bool = True,
    batch_size: Optional[int] = 1,
) -> InferenceModel:
    """
    Create a configuration for inference from a training configuration.

    The mean and standard deviation are always taken from the data configuration
    of `training_configuration`. If not provided (or falsy), `data_type` and
    `axes` are also taken from it. If `transforms` is not provided, only
    normalization is applied.

    Parameters
    ----------
    training_configuration : Configuration
        Configuration used for training; its `data_config` must have `mean` and
        `std` set.
    tile_size : Tuple[int, ...], optional
        Size of the tiles, by default None.
    tile_overlap : Tuple[int, ...], optional
        Overlap of the tiles, by default None.
    data_type : Literal["array", "tiff", "custom"], optional
        Type of the data, by default None (use the training configuration's
        `data_type`).
    axes : str, optional
        Axes of the data, by default None (use the training configuration's
        `axes`).
    transforms : List[Dict[str, Any]] or Compose, optional
        Transformations to apply to the data, by default None (normalization
        only).
    tta_transforms : bool, optional
        Whether to apply test-time augmentations, by default True.
    batch_size : int, optional
        Batch size, by default 1.

    Returns
    -------
    InferenceModel
        Configuration for inference.

    Raises
    ------
    ValueError
        If `mean` or `std` is None in the training data configuration.
    """
    train_data_config = training_configuration.data_config

    # Normalization statistics must come from training.
    if train_data_config.mean is None or train_data_config.std is None:
        raise ValueError("Mean and std must be provided in the training configuration.")

    if transforms is None:
        transforms = [
            {
                "name": SupportedTransform.NORMALIZE.value,
            },
        ]

    return InferenceModel(
        data_type=data_type or train_data_config.data_type,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes or train_data_config.axes,
        mean=train_data_config.mean,
        std=train_data_config.std,
        transforms=transforms,
        tta_transforms=tta_transforms,
        batch_size=batch_size,
    )