careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the package's registry page for more details.

Files changed (56):
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,46 +1,250 @@
1
1
  """Convenience functions to create configurations for training and inference."""
2
2
 
3
- from typing import Any, Dict, List, Literal, Optional
3
+ from typing import Any, Literal, Optional, Union
4
4
 
5
5
  from .architectures import UNetModel
6
6
  from .configuration_model import Configuration
7
7
  from .data_model import DataConfig
8
8
  from .fcn_algorithm_model import FCNAlgorithmConfig
9
9
  from .support import (
10
- SupportedAlgorithm,
11
10
  SupportedArchitecture,
12
- SupportedLoss,
13
11
  SupportedPixelManipulation,
14
12
  SupportedTransform,
15
13
  )
16
14
  from .training_model import TrainingConfig
15
+ from .transformations import (
16
+ N2VManipulateModel,
17
+ XYFlipModel,
18
+ XYRandomRotate90Model,
19
+ )
20
+
21
+
22
+ def _list_augmentations(
23
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]],
24
+ ) -> list[Union[XYFlipModel, XYRandomRotate90Model]]:
25
+ """
26
+ List the augmentations to apply.
27
+
28
+ Parameters
29
+ ----------
30
+ augmentations : list of transforms, optional
31
+ List of transforms to apply, either both or one of XYFlipModel and
32
+ XYRandomRotate90Model.
33
+
34
+ Returns
35
+ -------
36
+ list of transforms
37
+ List of transforms to apply.
38
+
39
+ Raises
40
+ ------
41
+ ValueError
42
+ If the transforms are not XYFlipModel or XYRandomRotate90Model.
43
+ ValueError
44
+ If there are duplicate transforms.
45
+ """
46
+ if augmentations is None:
47
+ transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [
48
+ XYFlipModel(),
49
+ XYRandomRotate90Model(),
50
+ ]
51
+ else:
52
+ # throw error if not all transforms are pydantic models
53
+ if not all(
54
+ isinstance(t, XYFlipModel) or isinstance(t, XYRandomRotate90Model)
55
+ for t in augmentations
56
+ ):
57
+ raise ValueError(
58
+ "Accepted transforms are either XYFlipModel or "
59
+ "XYRandomRotate90Model."
60
+ )
61
+
62
+ # check that there is no duplication
63
+ aug_types = [t.__class__ for t in augmentations]
64
+ if len(set(aug_types)) != len(aug_types):
65
+ raise ValueError("Duplicate transforms are not allowed.")
66
+
67
+ transform_list = augmentations
68
+
69
+ return transform_list
70
+
71
+
72
+ def _create_unet_configuration(
73
+ axes: str,
74
+ n_channels_in: int,
75
+ n_channels_out: int,
76
+ independent_channels: bool,
77
+ use_n2v2: bool,
78
+ model_params: Optional[dict[str, Any]] = None,
79
+ ) -> UNetModel:
80
+ """
81
+ Create a dictionary with the parameters of the UNet model.
82
+
83
+ Parameters
84
+ ----------
85
+ axes : str
86
+ Axes of the data.
87
+ n_channels_in : int
88
+ Number of input channels.
89
+ n_channels_out : int
90
+ Number of output channels.
91
+ independent_channels : bool
92
+ Whether to train all channels independently.
93
+ use_n2v2 : bool
94
+ Whether to use N2V2.
95
+ model_params : dict
96
+ UNetModel parameters.
97
+
98
+ Returns
99
+ -------
100
+ UNetModel
101
+ UNet model with the specified parameters.
102
+ """
103
+ if model_params is None:
104
+ model_params = {}
105
+
106
+ model_params["n2v2"] = use_n2v2
107
+ model_params["conv_dims"] = 3 if "Z" in axes else 2
108
+ model_params["in_channels"] = n_channels_in
109
+ model_params["num_classes"] = n_channels_out
110
+ model_params["independent_channels"] = independent_channels
111
+
112
+ return UNetModel(
113
+ architecture=SupportedArchitecture.UNET.value,
114
+ **model_params,
115
+ )
116
+
117
+
118
+ def _create_configuration(
119
+ algorithm: Literal["n2v", "care", "n2n"],
120
+ experiment_name: str,
121
+ data_type: Literal["array", "tiff", "custom"],
122
+ axes: str,
123
+ patch_size: list[int],
124
+ batch_size: int,
125
+ num_epochs: int,
126
+ augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]],
127
+ independent_channels: bool,
128
+ loss: Literal["n2v", "mae", "mse"],
129
+ n_channels_in: int,
130
+ n_channels_out: int,
131
+ logger: Literal["wandb", "tensorboard", "none"],
132
+ use_n2v2: bool = False,
133
+ model_params: Optional[dict] = None,
134
+ dataloader_params: Optional[dict] = None,
135
+ ) -> Configuration:
136
+ """
137
+ Create a configuration for training N2V, CARE or Noise2Noise.
17
138
 
139
+ Parameters
140
+ ----------
141
+ algorithm : {"n2v", "care", "n2n"}
142
+ Algorithm to use.
143
+ experiment_name : str
144
+ Name of the experiment.
145
+ data_type : {"array", "tiff", "custom"}
146
+ Type of the data.
147
+ axes : str
148
+ Axes of the data (e.g. SYX).
149
+ patch_size : list of int
150
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
151
+ batch_size : int
152
+ Batch size.
153
+ num_epochs : int
154
+ Number of epochs.
155
+ augmentations : list of transforms
156
+ List of transforms to apply, either both or one of XYFlipModel and
157
+ XYRandomRotate90Model.
158
+ independent_channels : bool
159
+ Whether to train all channels independently.
160
+ loss : {"n2v", "mae", "mse"}
161
+ Loss function to use.
162
+ n_channels_in : int
163
+ Number of channels in.
164
+ n_channels_out : int
165
+ Number of channels out.
166
+ logger : {"wandb", "tensorboard", "none"}
167
+ Logger to use.
168
+ use_n2v2 : bool, optional
169
+ Whether to use N2V2, by default False.
170
+ model_params : dict
171
+ UNetModel parameters.
172
+ dataloader_params : dict
173
+ Parameters for the dataloader, see PyTorch notes, by default None.
18
174
 
19
- # TODO rename ?
175
+ Returns
176
+ -------
177
+ Configuration
178
+ Configuration for training N2V, CARE or Noise2Noise.
179
+ """
180
+ # model
181
+ unet_model = _create_unet_configuration(
182
+ axes=axes,
183
+ n_channels_in=n_channels_in,
184
+ n_channels_out=n_channels_out,
185
+ independent_channels=independent_channels,
186
+ use_n2v2=use_n2v2,
187
+ model_params=model_params,
188
+ )
189
+
190
+ # algorithm model
191
+ algorithm_config = FCNAlgorithmConfig(
192
+ algorithm=algorithm,
193
+ loss=loss,
194
+ model=unet_model,
195
+ )
196
+
197
+ # data model
198
+ data = DataConfig(
199
+ data_type=data_type,
200
+ axes=axes,
201
+ patch_size=patch_size,
202
+ batch_size=batch_size,
203
+ transforms=augmentations,
204
+ dataloader_params=dataloader_params,
205
+ )
206
+
207
+ # training model
208
+ training = TrainingConfig(
209
+ num_epochs=num_epochs,
210
+ batch_size=batch_size,
211
+ logger=None if logger == "none" else logger,
212
+ )
213
+
214
+ # create configuration
215
+ configuration = Configuration(
216
+ experiment_name=experiment_name,
217
+ algorithm_config=algorithm_config,
218
+ data_config=data,
219
+ training_config=training,
220
+ )
221
+
222
+ return configuration
223
+
224
+
225
+ # TODO reconsider naming once we officially support LVAE approaches
20
226
  def _create_supervised_configuration(
21
- algorithm_type: Literal["fcn"],
22
227
  algorithm: Literal["care", "n2n"],
23
228
  experiment_name: str,
24
229
  data_type: Literal["array", "tiff", "custom"],
25
230
  axes: str,
26
- patch_size: List[int],
231
+ patch_size: list[int],
27
232
  batch_size: int,
28
233
  num_epochs: int,
29
- use_augmentations: bool = True,
30
- independent_channels: bool = False,
234
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
235
+ independent_channels: bool = True,
31
236
  loss: Literal["mae", "mse"] = "mae",
32
237
  n_channels_in: int = 1,
33
238
  n_channels_out: int = 1,
34
239
  logger: Literal["wandb", "tensorboard", "none"] = "none",
35
- model_kwargs: Optional[dict] = None,
240
+ model_params: Optional[dict] = None,
241
+ dataloader_params: Optional[dict] = None,
36
242
  ) -> Configuration:
37
243
  """
38
244
  Create a configuration for training CARE or Noise2Noise.
39
245
 
40
246
  Parameters
41
247
  ----------
42
- algorithm_type : Literal["fcn"]
43
- Type of the algorithm.
44
248
  algorithm : Literal["care", "n2n"]
45
249
  Algorithm to use.
46
250
  experiment_name : str
@@ -55,8 +259,10 @@ def _create_supervised_configuration(
55
259
  Batch size.
56
260
  num_epochs : int
57
261
  Number of epochs.
58
- use_augmentations : bool, optional
59
- Whether to use augmentations, by default True.
262
+ augmentations : list of transforms, default=None
263
+ List of transforms to apply, either both or one of XYFlipModel and
264
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
265
+ and XYRandomRotate90 (in XY) to the images.
60
266
  independent_channels : bool, optional
61
267
  Whether to train all channels independently, by default False.
62
268
  loss : Literal["mae", "mse"], optional
@@ -67,8 +273,10 @@ def _create_supervised_configuration(
67
273
  Number of channels out, by default 1.
68
274
  logger : Literal["wandb", "tensorboard", "none"], optional
69
275
  Logger to use, by default "none".
70
- model_kwargs : dict, optional
276
+ model_params : dict, optional
71
277
  UNetModel parameters, by default {}.
278
+ dataloader_params : dict, optional
279
+ Parameters for the dataloader, see PyTorch notes, by default None.
72
280
 
73
281
  Returns
74
282
  -------
@@ -87,81 +295,43 @@ def _create_supervised_configuration(
87
295
  f"(got {n_channels_in} channels)."
88
296
  )
89
297
 
90
- # model
91
- if model_kwargs is None:
92
- model_kwargs = {}
93
- model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
94
- model_kwargs["in_channels"] = n_channels_in
95
- model_kwargs["num_classes"] = n_channels_out
96
- model_kwargs["independent_channels"] = independent_channels
97
-
98
- unet_model = UNetModel(
99
- architecture=SupportedArchitecture.UNET.value,
100
- **model_kwargs,
101
- )
102
-
103
- # algorithm model
104
- algorithm = FCNAlgorithmConfig(
105
- algorithm_type=algorithm_type,
106
- algorithm=algorithm,
107
- loss=loss,
108
- model=unet_model,
109
- )
110
-
111
298
  # augmentations
112
- if use_augmentations:
113
- transforms: List[Dict[str, Any]] = [
114
- {
115
- "name": SupportedTransform.XY_FLIP.value,
116
- },
117
- {
118
- "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
119
- },
120
- ]
121
- else:
122
- transforms = []
299
+ transform_list = _list_augmentations(augmentations)
123
300
 
124
- # data model
125
- data = DataConfig(
301
+ return _create_configuration(
302
+ algorithm=algorithm,
303
+ experiment_name=experiment_name,
126
304
  data_type=data_type,
127
305
  axes=axes,
128
306
  patch_size=patch_size,
129
307
  batch_size=batch_size,
130
- transforms=transforms,
131
- )
132
-
133
- # training model
134
- training = TrainingConfig(
135
308
  num_epochs=num_epochs,
136
- batch_size=batch_size,
137
- logger=None if logger == "none" else logger,
138
- )
139
-
140
- # create configuration
141
- configuration = Configuration(
142
- experiment_name=experiment_name,
143
- algorithm_config=algorithm,
144
- data_config=data,
145
- training_config=training,
309
+ augmentations=transform_list,
310
+ independent_channels=independent_channels,
311
+ loss=loss,
312
+ n_channels_in=n_channels_in,
313
+ n_channels_out=n_channels_out,
314
+ logger=logger,
315
+ model_params=model_params,
316
+ dataloader_params=dataloader_params,
146
317
  )
147
318
 
148
- return configuration
149
-
150
319
 
151
320
  def create_care_configuration(
152
321
  experiment_name: str,
153
322
  data_type: Literal["array", "tiff", "custom"],
154
323
  axes: str,
155
- patch_size: List[int],
324
+ patch_size: list[int],
156
325
  batch_size: int,
157
326
  num_epochs: int,
158
- use_augmentations: bool = True,
159
- independent_channels: bool = False,
327
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
328
+ independent_channels: bool = True,
160
329
  loss: Literal["mae", "mse"] = "mae",
161
330
  n_channels_in: int = 1,
162
331
  n_channels_out: int = -1,
163
332
  logger: Literal["wandb", "tensorboard", "none"] = "none",
164
- model_kwargs: Optional[dict] = None,
333
+ model_params: Optional[dict] = None,
334
+ dataloader_params: Optional[dict] = None,
165
335
  ) -> Configuration:
166
336
  """
167
337
  Create a configuration for training CARE.
@@ -179,8 +349,10 @@ def create_care_configuration(
179
349
  By default, all channels are trained together. To train all channels independently,
180
350
  set `independent_channels` to True.
181
351
 
182
- By setting `use_augmentations` to False, the only transformation applied will be
183
- normalization.
352
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
353
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
354
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
355
+ disable the transforms, simply pass an empty list.
184
356
 
185
357
  Parameters
186
358
  ----------
@@ -196,8 +368,10 @@ def create_care_configuration(
196
368
  Batch size.
197
369
  num_epochs : int
198
370
  Number of epochs.
199
- use_augmentations : bool, optional
200
- Whether to use augmentations, by default True.
371
+ augmentations : list of transforms, default=None
372
+ List of transforms to apply, either both or one of XYFlipModel and
373
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
374
+ and XYRandomRotate90 (in XY) to the images.
201
375
  independent_channels : bool, optional
202
376
  Whether to train all channels independently, by default False.
203
377
  loss : Literal["mae", "mse"], optional
@@ -208,19 +382,87 @@ def create_care_configuration(
208
382
  Number of channels out, by default -1.
209
383
  logger : Literal["wandb", "tensorboard", "none"], optional
210
384
  Logger to use, by default "none".
211
- model_kwargs : dict, optional
212
- UNetModel parameters, by default {}.
385
+ model_params : dict, optional
386
+ UNetModel parameters, by default None.
387
+ dataloader_params : dict, optional
388
+ Parameters for the dataloader, see PyTorch notes, by default None.
213
389
 
214
390
  Returns
215
391
  -------
216
392
  Configuration
217
393
  Configuration for training CARE.
394
+
395
+ Examples
396
+ --------
397
+ Minimum example:
398
+ >>> config = create_care_configuration(
399
+ ... experiment_name="care_experiment",
400
+ ... data_type="array",
401
+ ... axes="YX",
402
+ ... patch_size=[64, 64],
403
+ ... batch_size=32,
404
+ ... num_epochs=100
405
+ ... )
406
+
407
+ To disable transforms, simply set `augmentations` to an empty list:
408
+ >>> config = create_care_configuration(
409
+ ... experiment_name="care_experiment",
410
+ ... data_type="array",
411
+ ... axes="YX",
412
+ ... patch_size=[64, 64],
413
+ ... batch_size=32,
414
+ ... num_epochs=100,
415
+ ... augmentations=[]
416
+ ... )
417
+
418
+ A list of transforms can be passed to the `augmentations` parameter to replace the
419
+ default augmentations:
420
+ >>> from careamics.config.transformations import XYFlipModel
421
+ >>> config = create_care_configuration(
422
+ ... experiment_name="care_experiment",
423
+ ... data_type="array",
424
+ ... axes="YX",
425
+ ... patch_size=[64, 64],
426
+ ... batch_size=32,
427
+ ... num_epochs=100,
428
+ ... augmentations=[
429
+ ... # No rotation and only Y flipping
430
+ ... XYFlipModel(flip_x = False, flip_y = True)
431
+ ... ]
432
+ ... )
433
+
434
+ If you are training multiple channels they will be trained independently by default,
435
+ you simply need to specify the number of channels input (and optionally, the number
436
+ of channels output):
437
+ >>> config = create_care_configuration(
438
+ ... experiment_name="care_experiment",
439
+ ... data_type="array",
440
+ ... axes="YXC", # channels must be in the axes
441
+ ... patch_size=[64, 64],
442
+ ... batch_size=32,
443
+ ... num_epochs=100,
444
+ ... n_channels_in=3, # number of input channels
445
+ ... n_channels_out=1 # if applicable
446
+ ... )
447
+
448
+ If instead you want to train multiple channels together, you need to turn off the
449
+ `independent_channels` parameter:
450
+ >>> config = create_care_configuration(
451
+ ... experiment_name="care_experiment",
452
+ ... data_type="array",
453
+ ... axes="YXC", # channels must be in the axes
454
+ ... patch_size=[64, 64],
455
+ ... batch_size=32,
456
+ ... num_epochs=100,
457
+ ... independent_channels=False,
458
+ ... n_channels_in=3,
459
+ ... n_channels_out=1 # if applicable
460
+ ... )
218
461
  """
219
462
  if n_channels_out == -1:
220
463
  n_channels_out = n_channels_in
221
464
 
222
465
  return _create_supervised_configuration(
223
- algorithm_type="fcn",
224
466
  algorithm="care",
225
467
  experiment_name=experiment_name,
226
468
  data_type=data_type,
@@ -228,13 +470,14 @@ def create_care_configuration(
228
470
  patch_size=patch_size,
229
471
  batch_size=batch_size,
230
472
  num_epochs=num_epochs,
231
- use_augmentations=use_augmentations,
473
+ augmentations=augmentations,
232
474
  independent_channels=independent_channels,
233
475
  loss=loss,
234
476
  n_channels_in=n_channels_in,
235
477
  n_channels_out=n_channels_out,
236
478
  logger=logger,
237
- model_kwargs=model_kwargs,
479
+ model_params=model_params,
480
+ dataloader_params=dataloader_params,
238
481
  )
239
482
 
240
483
 
@@ -242,16 +485,17 @@ def create_n2n_configuration(
242
485
  experiment_name: str,
243
486
  data_type: Literal["array", "tiff", "custom"],
244
487
  axes: str,
245
- patch_size: List[int],
488
+ patch_size: list[int],
246
489
  batch_size: int,
247
490
  num_epochs: int,
248
- use_augmentations: bool = True,
249
- independent_channels: bool = False,
491
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
492
+ independent_channels: bool = True,
250
493
  loss: Literal["mae", "mse"] = "mae",
251
494
  n_channels_in: int = 1,
252
495
  n_channels_out: int = -1,
253
496
  logger: Literal["wandb", "tensorboard", "none"] = "none",
254
- model_kwargs: Optional[dict] = None,
497
+ model_params: Optional[dict] = None,
498
+ dataloader_params: Optional[dict] = None,
255
499
  ) -> Configuration:
256
500
  """
257
501
  Create a configuration for training Noise2Noise.
@@ -269,8 +513,10 @@ def create_n2n_configuration(
269
513
  By default, all channels are trained together. To train all channels independently,
270
514
  set `independent_channels` to True.
271
515
 
272
- By setting `use_augmentations` to False, the only transformation applied will be
273
- normalization.
516
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
517
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
518
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
519
+ disable the transforms, simply pass an empty list.
274
520
 
275
521
  Parameters
276
522
  ----------
@@ -286,8 +532,10 @@ def create_n2n_configuration(
286
532
  Batch size.
287
533
  num_epochs : int
288
534
  Number of epochs.
289
- use_augmentations : bool, optional
290
- Whether to use augmentations, by default True.
535
+ augmentations : list of transforms, default=None
536
+ List of transforms to apply, either both or one of XYFlipModel and
537
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
538
+ and XYRandomRotate90 (in XY) to the images.
291
539
  independent_channels : bool, optional
292
540
  Whether to train all channels independently, by default False.
293
541
  loss : Literal["mae", "mse"], optional
@@ -298,19 +546,87 @@ def create_n2n_configuration(
298
546
  Number of channels out, by default -1.
299
547
  logger : Literal["wandb", "tensorboard", "none"], optional
300
548
  Logger to use, by default "none".
301
- model_kwargs : dict, optional
549
+ model_params : dict, optional
302
550
  UNetModel parameters, by default {}.
551
+ dataloader_params : dict, optional
552
+ Parameters for the dataloader, see PyTorch notes, by default None.
303
553
 
304
554
  Returns
305
555
  -------
306
556
  Configuration
307
557
  Configuration for training Noise2Noise.
558
+
559
+ Examples
560
+ --------
561
+ Minimum example:
562
+ >>> config = create_n2n_configuration(
563
+ ... experiment_name="n2n_experiment",
564
+ ... data_type="array",
565
+ ... axes="YX",
566
+ ... patch_size=[64, 64],
567
+ ... batch_size=32,
568
+ ... num_epochs=100
569
+ ... )
570
+
571
+ To disable transforms, simply set `augmentations` to an empty list:
572
+ >>> config = create_n2n_configuration(
573
+ ... experiment_name="n2n_experiment",
574
+ ... data_type="array",
575
+ ... axes="YX",
576
+ ... patch_size=[64, 64],
577
+ ... batch_size=32,
578
+ ... num_epochs=100,
579
+ ... augmentations=[]
580
+ ... )
581
+
582
+ A list of transforms can be passed to the `augmentations` parameter to replace the
583
+ default augmentations:
584
+ >>> from careamics.config.transformations import XYFlipModel
585
+ >>> config = create_n2n_configuration(
586
+ ... experiment_name="n2n_experiment",
587
+ ... data_type="array",
588
+ ... axes="YX",
589
+ ... patch_size=[64, 64],
590
+ ... batch_size=32,
591
+ ... num_epochs=100,
592
+ ... augmentations=[
593
+ ... # No rotation and only Y flipping
594
+ ... XYFlipModel(flip_x = False, flip_y = True)
595
+ ... ]
596
+ ... )
597
+
598
+ If you are training multiple channels they will be trained independently by default,
599
+ you simply need to specify the number of channels input (and optionally, the number
600
+ of channels output):
601
+ >>> config = create_n2n_configuration(
602
+ ... experiment_name="n2n_experiment",
603
+ ... data_type="array",
604
+ ... axes="YXC", # channels must be in the axes
605
+ ... patch_size=[64, 64],
606
+ ... batch_size=32,
607
+ ... num_epochs=100,
608
+ ... n_channels_in=3, # number of input channels
609
+ ... n_channels_out=1 # if applicable
610
+ ... )
611
+
612
+ If instead you want to train multiple channels together, you need to turn off the
613
+ `independent_channels` parameter:
614
+ >>> config = create_n2n_configuration(
615
+ ... experiment_name="n2n_experiment",
616
+ ... data_type="array",
617
+ ... axes="YXC", # channels must be in the axes
618
+ ... patch_size=[64, 64],
619
+ ... batch_size=32,
620
+ ... num_epochs=100,
621
+ ... independent_channels=False,
622
+ ... n_channels_in=3,
623
+ ... n_channels_out=1 # if applicable
624
+ ... )
308
625
  """
309
626
  if n_channels_out == -1:
310
627
  n_channels_out = n_channels_in
311
628
 
312
629
  return _create_supervised_configuration(
313
- algorithm_type="fcn",
314
630
  algorithm="n2n",
315
631
  experiment_name=experiment_name,
316
632
  data_type=data_type,
@@ -318,13 +634,14 @@ def create_n2n_configuration(
318
634
  patch_size=patch_size,
319
635
  batch_size=batch_size,
320
636
  num_epochs=num_epochs,
321
- use_augmentations=use_augmentations,
637
+ augmentations=augmentations,
322
638
  independent_channels=independent_channels,
323
639
  loss=loss,
324
640
  n_channels_in=n_channels_in,
325
641
  n_channels_out=n_channels_out,
326
642
  logger=logger,
327
- model_kwargs=model_kwargs,
643
+ model_params=model_params,
644
+ dataloader_params=dataloader_params,
328
645
  )
329
646
 
330
647
 
@@ -332,10 +649,10 @@ def create_n2v_configuration(
332
649
  experiment_name: str,
333
650
  data_type: Literal["array", "tiff", "custom"],
334
651
  axes: str,
335
- patch_size: List[int],
652
+ patch_size: list[int],
336
653
  batch_size: int,
337
654
  num_epochs: int,
338
- use_augmentations: bool = True,
655
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
339
656
  independent_channels: bool = True,
340
657
  use_n2v2: bool = False,
341
658
  n_channels: int = 1,
@@ -344,7 +661,8 @@ def create_n2v_configuration(
344
661
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
345
662
  struct_n2v_span: int = 5,
346
663
  logger: Literal["wandb", "tensorboard", "none"] = "none",
347
- model_kwargs: Optional[dict] = None,
664
+ model_params: Optional[dict] = None,
665
+ dataloader_params: Optional[dict] = None,
348
666
  ) -> Configuration:
349
667
  """
350
668
  Create a configuration for training Noise2Void.
@@ -367,16 +685,22 @@ def create_n2v_configuration(
367
685
  By default, all channels are trained independently. To train all channels together,
368
686
  set `independent_channels` to False.
369
687
 
370
- By setting `use_augmentations` to False, the only transformations applied will be
371
- normalization and N2V manipulation.
688
+ By default, the transformations applied are a random flip along X or Y, and a random
689
+ 90 degrees rotation in the XY plane. Normalization is always applied, as well as the
690
+ N2V manipulation.
691
+
692
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
693
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
694
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
695
+ disable the transforms, simply pass an empty list.
372
696
 
373
697
  The `roi_size` parameter specifies the size of the area around each pixel that will
374
698
  be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
375
699
  pixels per patch will be manipulated.
376
700
 
377
- The parameters of the UNet can be specified in the `model_kwargs` (passed as a
701
+ The parameters of the UNet can be specified in the `model_params` (passed as a
378
702
  parameter-value dictionary). Note that `use_n2v2` and 'n_channels' override the
379
- corresponding parameters passed in `model_kwargs`.
703
+ corresponding parameters passed in `model_params`.
380
704
 
381
705
  If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
382
706
  will be applied to each manipulated pixel.
@@ -395,8 +719,10 @@ def create_n2v_configuration(
395
719
  Batch size.
396
720
  num_epochs : int
397
721
  Number of epochs.
398
- use_augmentations : bool, optional
399
- Whether to use augmentations, by default True.
722
+ augmentations : list of transforms, default=None
723
+ List of transforms to apply, either both or one of XYFlipModel and
724
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
725
+ and XYRandomRotate90 (in XY) to the images.
400
726
  independent_channels : bool, optional
401
727
  Whether to train all channels together, by default True.
402
728
  use_n2v2 : bool, optional
@@ -413,8 +739,10 @@ def create_n2v_configuration(
413
739
  Span of the structN2V mask, by default 5.
414
740
  logger : Literal["wandb", "tensorboard", "none"], optional
415
741
  Logger to use, by default "none".
416
- model_kwargs : dict, optional
417
- UNetModel parameters, by default {}.
742
+ model_params : dict, optional
743
+ UNetModel parameters, by default None.
744
+ dataloader_params : dict, optional
745
+ Parameters for the dataloader, see PyTorch notes, by default None.
418
746
 
419
747
  Returns
420
748
  -------
@@ -433,6 +761,32 @@ def create_n2v_configuration(
433
761
  ... num_epochs=100
434
762
  ... )
435
763
 
764
+ To disable transforms, simply set `augmentations` to an empty list:
765
+ >>> config = create_n2v_configuration(
766
+ ... experiment_name="n2v_experiment",
767
+ ... data_type="array",
768
+ ... axes="YX",
769
+ ... patch_size=[64, 64],
770
+ ... batch_size=32,
771
+ ... num_epochs=100,
772
+ ... augmentations=[]
773
+ ... )
774
+
775
+ A list of transforms can be passed to the `augmentations` parameter:
776
+ >>> from careamics.config.transformations import XYFlipModel
777
+ >>> config = create_n2v_configuration(
778
+ ... experiment_name="n2v_experiment",
779
+ ... data_type="array",
780
+ ... axes="YX",
781
+ ... patch_size=[64, 64],
782
+ ... batch_size=32,
783
+ ... num_epochs=100,
784
+ ... augmentations=[
785
+ ... # No rotation and only Y flipping
786
+ ... XYFlipModel(flip_x = False, flip_y = True)
787
+ ... ]
788
+ ... )
789
+
436
790
  To use N2V2, simply pass the `use_n2v2` parameter:
437
791
  >>> config = create_n2v_configuration(
438
792
  ... experiment_name="n2v2_experiment",
@@ -457,8 +811,8 @@ def create_n2v_configuration(
457
811
  ... struct_n2v_span=7
458
812
  ... )
459
813
 
460
- If you are training multiple channels independently, then you need to specify the
461
- number of channels:
814
+ If you are training multiple channels they will be trained independently by default,
815
+ you simply need to specify the number of channels:
462
816
  >>> config = create_n2v_configuration(
463
817
  ... experiment_name="n2v_experiment",
464
818
  ... data_type="array",
@@ -481,18 +835,6 @@ def create_n2v_configuration(
481
835
  ... independent_channels=False,
482
836
  ... n_channels=3
483
837
  ... )
484
-
485
- To turn off the augmentations, except normalization and N2V manipulation, use the
486
- relevant keyword argument:
487
- >>> config = create_n2v_configuration(
488
- ... experiment_name="n2v_experiment",
489
- ... data_type="array",
490
- ... axes="YX",
491
- ... patch_size=[64, 64],
492
- ... batch_size=32,
493
- ... num_epochs=100,
494
- ... use_augmentations=False
495
- ... )
496
838
  """
497
839
  # if there are channels, we need to specify their number
498
840
  if "C" in axes and n_channels == 1:
@@ -506,78 +848,39 @@ def create_n2v_configuration(
506
848
  f"(got {n_channels} channel)."
507
849
  )
508
850
 
509
- # model
510
- if model_kwargs is None:
511
- model_kwargs = {}
512
- model_kwargs["n2v2"] = use_n2v2
513
- model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
514
- model_kwargs["in_channels"] = n_channels
515
- model_kwargs["num_classes"] = n_channels
516
- model_kwargs["independent_channels"] = independent_channels
517
-
518
- unet_model = UNetModel(
519
- architecture=SupportedArchitecture.UNET.value,
520
- **model_kwargs,
521
- )
522
-
523
- # algorithm model
524
- algorithm = FCNAlgorithmConfig(
525
- algorithm_type="fcn",
526
- algorithm=SupportedAlgorithm.N2V.value,
527
- loss=SupportedLoss.N2V.value,
528
- model=unet_model,
529
- )
530
-
531
851
  # augmentations
532
- if use_augmentations:
533
- transforms: List[Dict[str, Any]] = [
534
- {
535
- "name": SupportedTransform.XY_FLIP.value,
536
- },
537
- {
538
- "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
539
- },
540
- ]
541
- else:
542
- transforms = []
852
+ transform_list = _list_augmentations(augmentations)
543
853
 
544
- # n2v2 and structn2v
545
- nv2_transform = {
546
- "name": SupportedTransform.N2V_MANIPULATE.value,
547
- "strategy": (
854
+ # create the N2VManipulate transform using the supplied parameters
855
+ n2v_transform = N2VManipulateModel(
856
+ name=SupportedTransform.N2V_MANIPULATE.value,
857
+ strategy=(
548
858
  SupportedPixelManipulation.MEDIAN.value
549
859
  if use_n2v2
550
860
  else SupportedPixelManipulation.UNIFORM.value
551
861
  ),
552
- "roi_size": roi_size,
553
- "masked_pixel_percentage": masked_pixel_percentage,
554
- "struct_mask_axis": struct_n2v_axis,
555
- "struct_mask_span": struct_n2v_span,
556
- }
557
- transforms.append(nv2_transform)
862
+ roi_size=roi_size,
863
+ masked_pixel_percentage=masked_pixel_percentage,
864
+ struct_mask_axis=struct_n2v_axis,
865
+ struct_mask_span=struct_n2v_span,
866
+ )
867
+ transform_list.append(n2v_transform)
558
868
 
559
- # data model
560
- data = DataConfig(
869
+ return _create_configuration(
870
+ algorithm="n2v",
871
+ experiment_name=experiment_name,
561
872
  data_type=data_type,
562
873
  axes=axes,
563
874
  patch_size=patch_size,
564
875
  batch_size=batch_size,
565
- transforms=transforms,
566
- )
567
-
568
- # training model
569
- training = TrainingConfig(
570
876
  num_epochs=num_epochs,
571
- batch_size=batch_size,
572
- logger=None if logger == "none" else logger,
573
- )
574
-
575
- # create configuration
576
- configuration = Configuration(
577
- experiment_name=experiment_name,
578
- algorithm_config=algorithm,
579
- data_config=data,
580
- training_config=training,
877
+ augmentations=transform_list,
878
+ independent_channels=independent_channels,
879
+ loss="n2v",
880
+ use_n2v2=use_n2v2,
881
+ n_channels_in=n_channels,
882
+ n_channels_out=n_channels,
883
+ logger=logger,
884
+ model_params=model_params,
885
+ dataloader_params=dataloader_params,
581
886
  )
582
-
583
- return configuration