careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_factory.py
@@ -0,0 +1,597 @@
+ """Convenience functions to create configurations for training and inference."""
+
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+ from albumentations import Compose
+
+ from .algorithm_model import AlgorithmConfig
+ from .architectures import UNetModel
+ from .configuration_model import Configuration
+ from .data_model import DataConfig
+ from .inference_model import InferenceConfig
+ from .support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedPixelManipulation,
+     SupportedTransform,
+ )
+ from .training_model import TrainingConfig
+
+
+ def _create_supervised_configuration(
+     algorithm: Literal["care", "n2n"],
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels: int = -1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training CARE or Noise2Noise.
+
+     Parameters
+     ----------
+     algorithm : Literal["care", "n2n"]
+         Algorithm to use.
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels : int, optional
+         Number of channels (in and out), by default -1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training CARE or Noise2Noise.
+     """
+     # if there are channels, we need to specify their number
+     if "C" in axes and n_channels == 1:
+         raise ValueError(
+             f"Number of channels must be specified when using channels "
+             f"(got {n_channels} channel)."
+         )
+     elif "C" not in axes and n_channels > 1:
+         raise ValueError(
+             f"C is not present in the axes, but number of channels is specified "
+             f"(got {n_channels} channels)."
+         )
+
+     # model
+     if model_kwargs is None:
+         model_kwargs = {}
+     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+     model_kwargs["in_channels"] = n_channels
+     model_kwargs["num_classes"] = n_channels
+
+     unet_model = UNetModel(
+         architecture=SupportedArchitecture.UNET.value,
+         **model_kwargs,
+     )
+
+     # algorithm model
+     algorithm_config = AlgorithmConfig(
+         algorithm=algorithm,
+         loss=loss,
+         model=unet_model,
+     )
+
+     # augmentations
+     if use_augmentations:
+         transforms: List[Dict[str, Any]] = [
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+             {
+                 "name": SupportedTransform.NDFLIP.value,
+             },
+             {
+                 "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+             },
+         ]
+     else:
+         transforms = [
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+         ]
+
+     # data model
+     data = DataConfig(
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         transforms=transforms,
+     )
+
+     # training model
+     training = TrainingConfig(
+         num_epochs=num_epochs,
+         batch_size=batch_size,
+         logger=None if logger == "none" else logger,
+     )
+
+     # create configuration
+     configuration = Configuration(
+         experiment_name=experiment_name,
+         algorithm_config=algorithm_config,
+         data_config=data,
+         training_config=training,
+     )
+
+     return configuration
+
+
+ def create_care_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels: int = 1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training CARE.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels` to the number of
+     channels. Likewise, if you set the number of channels, then "C" must be present
+     in `axes`.
+
+     By setting `use_augmentations` to False, the only transformation applied will be
+     normalization.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels : int, optional
+         Number of channels (in and out), by default 1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training CARE.
+     """
+     return _create_supervised_configuration(
+         algorithm="care",
+         experiment_name=experiment_name,
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         num_epochs=num_epochs,
+         use_augmentations=use_augmentations,
+         loss=loss,
+         # TODO in the future we might support different in and out channels for CARE
+         n_channels=n_channels,
+         logger=logger,
+         model_kwargs=model_kwargs,
+     )
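A minimal usage sketch for the function above (not part of the diff; the import path follows the file location shown in the hunk header, and all argument values are illustrative):

from careamics.config.configuration_factory import create_care_configuration

# 2D TIFF data with a sample axis; defaults give MAE loss and augmentations on
config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
)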
+
+
+ def create_n2n_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     loss: Literal["mae", "mse"] = "mae",
+     n_channels: int = 1,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training Noise2Noise.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels` to the number of
+     channels. Likewise, if you set the number of channels, then "C" must be present
+     in `axes`.
+
+     By setting `use_augmentations` to False, the only transformation applied will be
+     normalization.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     loss : Literal["mae", "mse"], optional
+         Loss function to use, by default "mae".
+     n_channels : int, optional
+         Number of channels (in and out), by default 1.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training Noise2Noise.
+     """
+     return _create_supervised_configuration(
+         algorithm="n2n",
+         experiment_name=experiment_name,
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         num_epochs=num_epochs,
+         use_augmentations=use_augmentations,
+         loss=loss,
+         n_channels=n_channels,
+         logger=logger,
+         model_kwargs=model_kwargs,
+     )
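As above, a sketch (not part of the diff) showing the 3D case: with "Z" in `axes`, `patch_size` must have three entries, and the factory sets `conv_dims` to 3 internally. Values are illustrative:

from careamics.config.configuration_factory import create_n2n_configuration

# 3D stacks: "Z" in axes requires a 3-element patch_size
config = create_n2n_configuration(
    experiment_name="n2n_3d_experiment",
    data_type="tiff",
    axes="SZYX",
    patch_size=[16, 64, 64],
    batch_size=8,
    num_epochs=50,
)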
+
+
+ def create_n2v_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "custom"],
+     axes: str,
+     patch_size: List[int],
+     batch_size: int,
+     num_epochs: int,
+     use_augmentations: bool = True,
+     use_n2v2: bool = False,
+     n_channels: int = 1,
+     roi_size: int = 11,
+     masked_pixel_percentage: float = 0.2,
+     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+     struct_n2v_span: int = 5,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_kwargs: Optional[dict] = None,
+ ) -> Configuration:
+     """
+     Create a configuration for training Noise2Void.
+
+     N2V uses a UNet model to denoise images in a self-supervised manner. To use its
+     variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
+     (structN2V) parameters, or set `use_n2v2` to True (N2V2).
+
+     N2V2 modifies the UNet architecture by adding blur pool layers and removing the
+     skip connections, thus suppressing checkerboard artefacts. StructN2V is used
+     when vertical or horizontal correlations are present in the noise; it applies an
+     additional mask to the neighbors of the manipulated pixels.
+
+     If "C" is present in `axes`, then you need to set `n_channels` to the number of
+     channels.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     By setting `use_augmentations` to False, the only transformations applied will
+     be normalization and N2V manipulation.
+
+     The `roi_size` parameter specifies the size of the area around each pixel that
+     will be manipulated by N2V. The `masked_pixel_percentage` parameter specifies
+     how many pixels per patch will be manipulated.
+
+     The parameters of the UNet can be specified in `model_kwargs` (passed as a
+     parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
+     corresponding parameters passed in `model_kwargs`.
+
+     If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then a structN2V
+     mask will be applied to each manipulated pixel.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : Literal["array", "tiff", "custom"]
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : List[int]
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int
+         Number of epochs.
+     use_augmentations : bool, optional
+         Whether to use augmentations, by default True.
+     use_n2v2 : bool, optional
+         Whether to use N2V2, by default False.
+     n_channels : int, optional
+         Number of channels (in and out), by default 1.
+     roi_size : int, optional
+         N2V pixel manipulation area, by default 11.
+     masked_pixel_percentage : float, optional
+         Percentage of pixels masked in each patch, by default 0.2.
+     struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+         Axis along which to apply structN2V mask, by default "none".
+     struct_n2v_span : int, optional
+         Span of the structN2V mask, by default 5.
+     logger : Literal["wandb", "tensorboard", "none"], optional
+         Logger to use, by default "none".
+     model_kwargs : dict, optional
+         UNetModel parameters, by default {}.
+
+     Returns
+     -------
+     Configuration
+         Configuration for training N2V.
+
+     Examples
+     --------
+     Minimum example:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100
+     ... )
+
+     To use N2V2, simply pass the `use_n2v2` parameter:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v2_experiment",
+     ...     data_type="tiff",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     use_n2v2=True
+     ... )
+
+     For structN2V, there are two parameters to set, `struct_n2v_axis` and
+     `struct_n2v_span`:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="structn2v_experiment",
+     ...     data_type="tiff",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     struct_n2v_axis="horizontal",
+     ...     struct_n2v_span=7
+     ... )
+
+     If you are training multiple channels together, then you need to specify the
+     number of channels:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YXC",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     n_channels=3
+     ... )
+
+     To turn off all augmentations except normalization and N2V manipulation, use
+     the relevant keyword argument:
+     >>> config = create_n2v_configuration(
+     ...     experiment_name="n2v_experiment",
+     ...     data_type="array",
+     ...     axes="YX",
+     ...     patch_size=[64, 64],
+     ...     batch_size=32,
+     ...     num_epochs=100,
+     ...     use_augmentations=False
+     ... )
+     """
+     # if there are channels, we need to specify their number
+     if "C" in axes and n_channels == 1:
+         raise ValueError(
+             f"Number of channels must be specified when using channels "
+             f"(got {n_channels} channel)."
+         )
+     elif "C" not in axes and n_channels > 1:
+         raise ValueError(
+             f"C is not present in the axes, but number of channels is specified "
+             f"(got {n_channels} channels)."
+         )
+
+     # model
+     if model_kwargs is None:
+         model_kwargs = {}
+     model_kwargs["n2v2"] = use_n2v2
+     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
+     model_kwargs["in_channels"] = n_channels
+     model_kwargs["num_classes"] = n_channels
+
+     unet_model = UNetModel(
+         architecture=SupportedArchitecture.UNET.value,
+         **model_kwargs,
+     )
+
+     # algorithm model
+     algorithm = AlgorithmConfig(
+         algorithm=SupportedAlgorithm.N2V.value,
+         loss=SupportedLoss.N2V.value,
+         model=unet_model,
+     )
+
+     # augmentations
+     if use_augmentations:
+         transforms: List[Dict[str, Any]] = [
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+             {
+                 "name": SupportedTransform.NDFLIP.value,
+             },
+             {
+                 "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
+             },
+         ]
+     else:
+         transforms = [
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+         ]
+
+     # N2V manipulation (covers the N2V2 and structN2V variants)
+     n2v_transform = {
+         "name": SupportedTransform.N2V_MANIPULATE.value,
+         "strategy": SupportedPixelManipulation.MEDIAN.value
+         if use_n2v2
+         else SupportedPixelManipulation.UNIFORM.value,
+         "roi_size": roi_size,
+         "masked_pixel_percentage": masked_pixel_percentage,
+         "struct_mask_axis": struct_n2v_axis,
+         "struct_mask_span": struct_n2v_span,
+     }
+     transforms.append(n2v_transform)
+
+     # data model
+     data = DataConfig(
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         transforms=transforms,
+     )
+
+     # training model
+     training = TrainingConfig(
+         num_epochs=num_epochs,
+         batch_size=batch_size,
+         logger=None if logger == "none" else logger,
+     )
+
+     # create configuration
+     configuration = Configuration(
+         experiment_name=experiment_name,
+         algorithm_config=algorithm,
+         data_config=data,
+         training_config=training,
+     )
+
+     return configuration
+
+
+ # TODO add tests
+ def create_inference_configuration(
+     training_configuration: Configuration,
+     tile_size: Optional[Tuple[int, ...]] = None,
+     tile_overlap: Optional[Tuple[int, ...]] = None,
+     data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+     axes: Optional[str] = None,
+     transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
+     tta_transforms: bool = True,
+     batch_size: Optional[int] = 1,
+ ) -> InferenceConfig:
+     """
+     Create a configuration for inference with N2V.
+
+     If not provided, `data_type` and `axes` are taken from the training
+     configuration. If `transforms` are not provided, only normalization is applied.
+
+     Parameters
+     ----------
+     training_configuration : Configuration
+         Configuration used for training.
+     tile_size : Tuple[int, ...], optional
+         Size of the tiles, by default None.
+     tile_overlap : Tuple[int, ...], optional
+         Overlap of the tiles, by default None.
+     data_type : str, optional
+         Type of the data, by default None (taken from the training configuration).
+     axes : str, optional
+         Axes of the data, by default None (taken from the training configuration).
+     transforms : List[Dict[str, Any]] or Compose, optional
+         Transformations to apply to the data, by default None.
+     tta_transforms : bool, optional
+         Whether to apply test-time augmentations, by default True.
+     batch_size : int, optional
+         Batch size, by default 1.
+
+     Returns
+     -------
+     InferenceConfig
+         Configuration for inference with N2V.
+     """
+     if (
+         training_configuration.data_config.mean is None
+         or training_configuration.data_config.std is None
+     ):
+         raise ValueError("Mean and std must be provided in the training configuration.")
+
+     if transforms is None:
+         transforms = [
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+         ]
+
+     return InferenceConfig(
+         data_type=data_type or training_configuration.data_config.data_type,
+         tile_size=tile_size,
+         tile_overlap=tile_overlap,
+         axes=axes or training_configuration.data_config.axes,
+         mean=training_configuration.data_config.mean,
+         std=training_configuration.data_config.std,
+         transforms=transforms,
+         tta_transforms=tta_transforms,
+         batch_size=batch_size,
+     )
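To close the loop, a sketch (not part of the diff) chaining a training configuration into an inference configuration. Note the guard in the function above: `mean` and `std` must be set on the training configuration's data config, which normally happens during training; they are set by hand here only so the sketch is self-contained, and the tile values are illustrative:

from careamics.config.configuration_factory import (
    create_inference_configuration,
    create_n2v_configuration,
)

train_config = create_n2v_configuration(
    experiment_name="n2v_experiment",
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
)

# normally computed from the training data; assigned manually for this sketch
train_config.data_config.mean = 128.0
train_config.data_config.std = 32.0

inference_config = create_inference_configuration(
    training_configuration=train_config,
    tile_size=(256, 256),
    tile_overlap=(48, 48),
)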