careamics-0.0.2-py3-none-any.whl → careamics-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the package's registry page for more details.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,36 +1,244 @@
1
1
  """Convenience functions to create configurations for training and inference."""
2
2
 
3
- from typing import Any, Dict, List, Literal, Optional
3
+ from typing import Any, Literal, Optional, Union
4
4
 
5
- from .algorithm_model import AlgorithmConfig
6
5
  from .architectures import UNetModel
7
6
  from .configuration_model import Configuration
8
7
  from .data_model import DataConfig
8
+ from .fcn_algorithm_model import FCNAlgorithmConfig
9
9
  from .support import (
10
- SupportedAlgorithm,
11
10
  SupportedArchitecture,
12
- SupportedLoss,
13
11
  SupportedPixelManipulation,
14
12
  SupportedTransform,
15
13
  )
16
14
  from .training_model import TrainingConfig
15
+ from .transformations import (
16
+ N2VManipulateModel,
17
+ XYFlipModel,
18
+ XYRandomRotate90Model,
19
+ )
20
+
21
+
22
+ def _list_augmentations(
23
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]],
24
+ ) -> list[Union[XYFlipModel, XYRandomRotate90Model]]:
25
+ """
26
+ List the augmentations to apply.
27
+
28
+ Parameters
29
+ ----------
30
+ augmentations : list of transforms, optional
31
+ List of transforms to apply, either both or one of XYFlipModel and
32
+ XYRandomRotate90Model.
33
+
34
+ Returns
35
+ -------
36
+ list of transforms
37
+ List of transforms to apply.
38
+
39
+ Raises
40
+ ------
41
+ ValueError
42
+ If the transforms are not XYFlipModel or XYRandomRotate90Model.
43
+ ValueError
44
+ If there are duplicate transforms.
45
+ """
46
+ if augmentations is None:
47
+ transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [
48
+ XYFlipModel(),
49
+ XYRandomRotate90Model(),
50
+ ]
51
+ else:
52
+ # throw error if not all transforms are pydantic models
53
+ if not all(
54
+ isinstance(t, XYFlipModel) or isinstance(t, XYRandomRotate90Model)
55
+ for t in augmentations
56
+ ):
57
+ raise ValueError(
58
+ "Accepted transforms are either XYFlipModel or "
59
+ "XYRandomRotate90Model."
60
+ )
61
+
62
+ # check that there is no duplication
63
+ aug_types = [t.__class__ for t in augmentations]
64
+ if len(set(aug_types)) != len(aug_types):
65
+ raise ValueError("Duplicate transforms are not allowed.")
66
+
67
+ transform_list = augmentations
68
+
69
+ return transform_list
70
+
71
+
72
+ def _create_unet_configuration(
73
+ axes: str,
74
+ n_channels_in: int,
75
+ n_channels_out: int,
76
+ independent_channels: bool,
77
+ use_n2v2: bool,
78
+ model_params: Optional[dict[str, Any]] = None,
79
+ ) -> UNetModel:
80
+ """
81
+ Create a dictionary with the parameters of the UNet model.
82
+
83
+ Parameters
84
+ ----------
85
+ axes : str
86
+ Axes of the data.
87
+ n_channels_in : int
88
+ Number of input channels.
89
+ n_channels_out : int
90
+ Number of output channels.
91
+ independent_channels : bool
92
+ Whether to train all channels independently.
93
+ use_n2v2 : bool
94
+ Whether to use N2V2.
95
+ model_params : dict
96
+ UNetModel parameters.
97
+
98
+ Returns
99
+ -------
100
+ UNetModel
101
+ UNet model with the specified parameters.
102
+ """
103
+ if model_params is None:
104
+ model_params = {}
105
+
106
+ model_params["n2v2"] = use_n2v2
107
+ model_params["conv_dims"] = 3 if "Z" in axes else 2
108
+ model_params["in_channels"] = n_channels_in
109
+ model_params["num_classes"] = n_channels_out
110
+ model_params["independent_channels"] = independent_channels
111
+
112
+ return UNetModel(
113
+ architecture=SupportedArchitecture.UNET.value,
114
+ **model_params,
115
+ )
116
+
117
+
118
+ def _create_configuration(
119
+ algorithm: Literal["n2v", "care", "n2n"],
120
+ experiment_name: str,
121
+ data_type: Literal["array", "tiff", "custom"],
122
+ axes: str,
123
+ patch_size: list[int],
124
+ batch_size: int,
125
+ num_epochs: int,
126
+ augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]],
127
+ independent_channels: bool,
128
+ loss: Literal["n2v", "mae", "mse"],
129
+ n_channels_in: int,
130
+ n_channels_out: int,
131
+ logger: Literal["wandb", "tensorboard", "none"],
132
+ use_n2v2: bool = False,
133
+ model_params: Optional[dict] = None,
134
+ dataloader_params: Optional[dict] = None,
135
+ ) -> Configuration:
136
+ """
137
+ Create a configuration for training N2V, CARE or Noise2Noise.
17
138
 
139
+ Parameters
140
+ ----------
141
+ algorithm : {"n2v", "care", "n2n"}
142
+ Algorithm to use.
143
+ experiment_name : str
144
+ Name of the experiment.
145
+ data_type : {"array", "tiff", "custom"}
146
+ Type of the data.
147
+ axes : str
148
+ Axes of the data (e.g. SYX).
149
+ patch_size : list of int
150
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
151
+ batch_size : int
152
+ Batch size.
153
+ num_epochs : int
154
+ Number of epochs.
155
+ augmentations : list of transforms
156
+ List of transforms to apply, either both or one of XYFlipModel and
157
+ XYRandomRotate90Model.
158
+ independent_channels : bool
159
+ Whether to train all channels independently.
160
+ loss : {"n2v", "mae", "mse"}
161
+ Loss function to use.
162
+ n_channels_in : int
163
+ Number of channels in.
164
+ n_channels_out : int
165
+ Number of channels out.
166
+ logger : {"wandb", "tensorboard", "none"}
167
+ Logger to use.
168
+ use_n2v2 : bool, optional
169
+ Whether to use N2V2, by default False.
170
+ model_params : dict
171
+ UNetModel parameters.
172
+ dataloader_params : dict
173
+ Parameters for the dataloader, see PyTorch notes, by default None.
18
174
 
175
+ Returns
176
+ -------
177
+ Configuration
178
+ Configuration for training N2V, CARE or Noise2Noise.
179
+ """
180
+ # model
181
+ unet_model = _create_unet_configuration(
182
+ axes=axes,
183
+ n_channels_in=n_channels_in,
184
+ n_channels_out=n_channels_out,
185
+ independent_channels=independent_channels,
186
+ use_n2v2=use_n2v2,
187
+ model_params=model_params,
188
+ )
189
+
190
+ # algorithm model
191
+ algorithm_config = FCNAlgorithmConfig(
192
+ algorithm=algorithm,
193
+ loss=loss,
194
+ model=unet_model,
195
+ )
196
+
197
+ # data model
198
+ data = DataConfig(
199
+ data_type=data_type,
200
+ axes=axes,
201
+ patch_size=patch_size,
202
+ batch_size=batch_size,
203
+ transforms=augmentations,
204
+ dataloader_params=dataloader_params,
205
+ )
206
+
207
+ # training model
208
+ training = TrainingConfig(
209
+ num_epochs=num_epochs,
210
+ batch_size=batch_size,
211
+ logger=None if logger == "none" else logger,
212
+ )
213
+
214
+ # create configuration
215
+ configuration = Configuration(
216
+ experiment_name=experiment_name,
217
+ algorithm_config=algorithm_config,
218
+ data_config=data,
219
+ training_config=training,
220
+ )
221
+
222
+ return configuration
223
+
224
+
225
+ # TODO reconsider naming once we officially support LVAE approaches
19
226
  def _create_supervised_configuration(
20
227
  algorithm: Literal["care", "n2n"],
21
228
  experiment_name: str,
22
229
  data_type: Literal["array", "tiff", "custom"],
23
230
  axes: str,
24
- patch_size: List[int],
231
+ patch_size: list[int],
25
232
  batch_size: int,
26
233
  num_epochs: int,
27
- use_augmentations: bool = True,
28
- independent_channels: bool = False,
234
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
235
+ independent_channels: bool = True,
29
236
  loss: Literal["mae", "mse"] = "mae",
30
237
  n_channels_in: int = 1,
31
238
  n_channels_out: int = 1,
32
239
  logger: Literal["wandb", "tensorboard", "none"] = "none",
33
- model_kwargs: Optional[dict] = None,
240
+ model_params: Optional[dict] = None,
241
+ dataloader_params: Optional[dict] = None,
34
242
  ) -> Configuration:
35
243
  """
36
244
  Create a configuration for training CARE or Noise2Noise.
@@ -51,8 +259,10 @@ def _create_supervised_configuration(
51
259
  Batch size.
52
260
  num_epochs : int
53
261
  Number of epochs.
54
- use_augmentations : bool, optional
55
- Whether to use augmentations, by default True.
262
+ augmentations : list of transforms, default=None
263
+ List of transforms to apply, either both or one of XYFlipModel and
264
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
265
+ and XYRandomRotate90 (in XY) to the images.
56
266
  independent_channels : bool, optional
57
267
  Whether to train all channels independently, by default False.
58
268
  loss : Literal["mae", "mse"], optional
@@ -63,8 +273,10 @@ def _create_supervised_configuration(
63
273
  Number of channels out, by default 1.
64
274
  logger : Literal["wandb", "tensorboard", "none"], optional
65
275
  Logger to use, by default "none".
66
- model_kwargs : dict, optional
276
+ model_params : dict, optional
67
277
  UNetModel parameters, by default {}.
278
+ dataloader_params : dict, optional
279
+ Parameters for the dataloader, see PyTorch notes, by default None.
68
280
 
69
281
  Returns
70
282
  -------
@@ -83,80 +295,43 @@ def _create_supervised_configuration(
83
295
  f"(got {n_channels_in} channels)."
84
296
  )
85
297
 
86
- # model
87
- if model_kwargs is None:
88
- model_kwargs = {}
89
- model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
90
- model_kwargs["in_channels"] = n_channels_in
91
- model_kwargs["num_classes"] = n_channels_out
92
- model_kwargs["independent_channels"] = independent_channels
93
-
94
- unet_model = UNetModel(
95
- architecture=SupportedArchitecture.UNET.value,
96
- **model_kwargs,
97
- )
98
-
99
- # algorithm model
100
- algorithm = AlgorithmConfig(
101
- algorithm=algorithm,
102
- loss=loss,
103
- model=unet_model,
104
- )
105
-
106
298
  # augmentations
107
- if use_augmentations:
108
- transforms: List[Dict[str, Any]] = [
109
- {
110
- "name": SupportedTransform.XY_FLIP.value,
111
- },
112
- {
113
- "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
114
- },
115
- ]
116
- else:
117
- transforms = []
299
+ transform_list = _list_augmentations(augmentations)
118
300
 
119
- # data model
120
- data = DataConfig(
301
+ return _create_configuration(
302
+ algorithm=algorithm,
303
+ experiment_name=experiment_name,
121
304
  data_type=data_type,
122
305
  axes=axes,
123
306
  patch_size=patch_size,
124
307
  batch_size=batch_size,
125
- transforms=transforms,
126
- )
127
-
128
- # training model
129
- training = TrainingConfig(
130
308
  num_epochs=num_epochs,
131
- batch_size=batch_size,
132
- logger=None if logger == "none" else logger,
133
- )
134
-
135
- # create configuration
136
- configuration = Configuration(
137
- experiment_name=experiment_name,
138
- algorithm_config=algorithm,
139
- data_config=data,
140
- training_config=training,
309
+ augmentations=transform_list,
310
+ independent_channels=independent_channels,
311
+ loss=loss,
312
+ n_channels_in=n_channels_in,
313
+ n_channels_out=n_channels_out,
314
+ logger=logger,
315
+ model_params=model_params,
316
+ dataloader_params=dataloader_params,
141
317
  )
142
318
 
143
- return configuration
144
-
145
319
 
146
320
  def create_care_configuration(
147
321
  experiment_name: str,
148
322
  data_type: Literal["array", "tiff", "custom"],
149
323
  axes: str,
150
- patch_size: List[int],
324
+ patch_size: list[int],
151
325
  batch_size: int,
152
326
  num_epochs: int,
153
- use_augmentations: bool = True,
154
- independent_channels: bool = False,
327
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
328
+ independent_channels: bool = True,
155
329
  loss: Literal["mae", "mse"] = "mae",
156
330
  n_channels_in: int = 1,
157
331
  n_channels_out: int = -1,
158
332
  logger: Literal["wandb", "tensorboard", "none"] = "none",
159
- model_kwargs: Optional[dict] = None,
333
+ model_params: Optional[dict] = None,
334
+ dataloader_params: Optional[dict] = None,
160
335
  ) -> Configuration:
161
336
  """
162
337
  Create a configuration for training CARE.
@@ -174,8 +349,10 @@ def create_care_configuration(
174
349
  By default, all channels are trained together. To train all channels independently,
175
350
  set `independent_channels` to True.
176
351
 
177
- By setting `use_augmentations` to False, the only transformation applied will be
178
- normalization.
352
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
353
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
354
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
355
+ disable the transforms, simply pass an empty list.
179
356
 
180
357
  Parameters
181
358
  ----------
@@ -191,8 +368,10 @@ def create_care_configuration(
191
368
  Batch size.
192
369
  num_epochs : int
193
370
  Number of epochs.
194
- use_augmentations : bool, optional
195
- Whether to use augmentations, by default True.
371
+ augmentations : list of transforms, default=None
372
+ List of transforms to apply, either both or one of XYFlipModel and
373
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
374
+ and XYRandomRotate90 (in XY) to the images.
196
375
  independent_channels : bool, optional
197
376
  Whether to train all channels independently, by default False.
198
377
  loss : Literal["mae", "mse"], optional
@@ -203,13 +382,82 @@ def create_care_configuration(
203
382
  Number of channels out, by default -1.
204
383
  logger : Literal["wandb", "tensorboard", "none"], optional
205
384
  Logger to use, by default "none".
206
- model_kwargs : dict, optional
207
- UNetModel parameters, by default {}.
385
+ model_params : dict, optional
386
+ UNetModel parameters, by default None.
387
+ dataloader_params : dict, optional
388
+ Parameters for the dataloader, see PyTorch notes, by default None.
208
389
 
209
390
  Returns
210
391
  -------
211
392
  Configuration
212
393
  Configuration for training CARE.
394
+
395
+ Examples
396
+ --------
397
+ Minimum example:
398
+ >>> config = create_care_configuration(
399
+ ... experiment_name="care_experiment",
400
+ ... data_type="array",
401
+ ... axes="YX",
402
+ ... patch_size=[64, 64],
403
+ ... batch_size=32,
404
+ ... num_epochs=100
405
+ ... )
406
+
407
+ To disable transforms, simply set `augmentations` to an empty list:
408
+ >>> config = create_care_configuration(
409
+ ... experiment_name="care_experiment",
410
+ ... data_type="array",
411
+ ... axes="YX",
412
+ ... patch_size=[64, 64],
413
+ ... batch_size=32,
414
+ ... num_epochs=100,
415
+ ... augmentations=[]
416
+ ... )
417
+
418
+ A list of transforms can be passed to the `augmentations` parameter to replace the
419
+ default augmentations:
420
+ >>> from careamics.config.transformations import XYFlipModel
421
+ >>> config = create_care_configuration(
422
+ ... experiment_name="care_experiment",
423
+ ... data_type="array",
424
+ ... axes="YX",
425
+ ... patch_size=[64, 64],
426
+ ... batch_size=32,
427
+ ... num_epochs=100,
428
+ ... augmentations=[
429
+ ... # No rotation and only Y flipping
430
+ ... XYFlipModel(flip_x = False, flip_y = True)
431
+ ... ]
432
+ ... )
433
+
434
+ If you are training multiple channels they will be trained independently by default,
435
+ you simply need to specify the number of channels input (and optionally, the number
436
+ of channels output):
437
+ >>> config = create_care_configuration(
438
+ ... experiment_name="care_experiment",
439
+ ... data_type="array",
440
+ ... axes="YXC", # channels must be in the axes
441
+ ... patch_size=[64, 64],
442
+ ... batch_size=32,
443
+ ... num_epochs=100,
444
+ ... n_channels_in=3, # number of input channels
445
+ ... n_channels_out=1 # if applicable
446
+ ... )
447
+
448
+ If instead you want to train multiple channels together, you need to turn off the
449
+ `independent_channels` parameter:
450
+ >>> config = create_care_configuration(
451
+ ... experiment_name="care_experiment",
452
+ ... data_type="array",
453
+ ... axes="YXC", # channels must be in the axes
454
+ ... patch_size=[64, 64],
455
+ ... batch_size=32,
456
+ ... num_epochs=100,
457
+ ... independent_channels=False,
458
+ ... n_channels_in=3,
459
+ ... n_channels_out=1 # if applicable
460
+ ... )
213
461
  """
214
462
  if n_channels_out == -1:
215
463
  n_channels_out = n_channels_in
@@ -222,13 +470,14 @@ def create_care_configuration(
222
470
  patch_size=patch_size,
223
471
  batch_size=batch_size,
224
472
  num_epochs=num_epochs,
225
- use_augmentations=use_augmentations,
473
+ augmentations=augmentations,
226
474
  independent_channels=independent_channels,
227
475
  loss=loss,
228
476
  n_channels_in=n_channels_in,
229
477
  n_channels_out=n_channels_out,
230
478
  logger=logger,
231
- model_kwargs=model_kwargs,
479
+ model_params=model_params,
480
+ dataloader_params=dataloader_params,
232
481
  )
233
482
 
234
483
 
@@ -236,16 +485,17 @@ def create_n2n_configuration(
236
485
  experiment_name: str,
237
486
  data_type: Literal["array", "tiff", "custom"],
238
487
  axes: str,
239
- patch_size: List[int],
488
+ patch_size: list[int],
240
489
  batch_size: int,
241
490
  num_epochs: int,
242
- use_augmentations: bool = True,
243
- independent_channels: bool = False,
491
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
492
+ independent_channels: bool = True,
244
493
  loss: Literal["mae", "mse"] = "mae",
245
494
  n_channels_in: int = 1,
246
495
  n_channels_out: int = -1,
247
496
  logger: Literal["wandb", "tensorboard", "none"] = "none",
248
- model_kwargs: Optional[dict] = None,
497
+ model_params: Optional[dict] = None,
498
+ dataloader_params: Optional[dict] = None,
249
499
  ) -> Configuration:
250
500
  """
251
501
  Create a configuration for training Noise2Noise.
@@ -263,8 +513,10 @@ def create_n2n_configuration(
263
513
  By default, all channels are trained together. To train all channels independently,
264
514
  set `independent_channels` to True.
265
515
 
266
- By setting `use_augmentations` to False, the only transformation applied will be
267
- normalization.
516
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
517
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
518
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
519
+ disable the transforms, simply pass an empty list.
268
520
 
269
521
  Parameters
270
522
  ----------
@@ -280,8 +532,10 @@ def create_n2n_configuration(
280
532
  Batch size.
281
533
  num_epochs : int
282
534
  Number of epochs.
283
- use_augmentations : bool, optional
284
- Whether to use augmentations, by default True.
535
+ augmentations : list of transforms, default=None
536
+ List of transforms to apply, either both or one of XYFlipModel and
537
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
538
+ and XYRandomRotate90 (in XY) to the images.
285
539
  independent_channels : bool, optional
286
540
  Whether to train all channels independently, by default False.
287
541
  loss : Literal["mae", "mse"], optional
@@ -292,13 +546,82 @@ def create_n2n_configuration(
292
546
  Number of channels out, by default -1.
293
547
  logger : Literal["wandb", "tensorboard", "none"], optional
294
548
  Logger to use, by default "none".
295
- model_kwargs : dict, optional
549
+ model_params : dict, optional
296
550
  UNetModel parameters, by default {}.
551
+ dataloader_params : dict, optional
552
+ Parameters for the dataloader, see PyTorch notes, by default None.
297
553
 
298
554
  Returns
299
555
  -------
300
556
  Configuration
301
557
  Configuration for training Noise2Noise.
558
+
559
+ Examples
560
+ --------
561
+ Minimum example:
562
+ >>> config = create_n2n_configuration(
563
+ ... experiment_name="n2n_experiment",
564
+ ... data_type="array",
565
+ ... axes="YX",
566
+ ... patch_size=[64, 64],
567
+ ... batch_size=32,
568
+ ... num_epochs=100
569
+ ... )
570
+
571
+ To disable transforms, simply set `augmentations` to an empty list:
572
+ >>> config = create_n2n_configuration(
573
+ ... experiment_name="n2n_experiment",
574
+ ... data_type="array",
575
+ ... axes="YX",
576
+ ... patch_size=[64, 64],
577
+ ... batch_size=32,
578
+ ... num_epochs=100,
579
+ ... augmentations=[]
580
+ ... )
581
+
582
+ A list of transforms can be passed to the `augmentations` parameter to replace the
583
+ default augmentations:
584
+ >>> from careamics.config.transformations import XYFlipModel
585
+ >>> config = create_n2n_configuration(
586
+ ... experiment_name="n2n_experiment",
587
+ ... data_type="array",
588
+ ... axes="YX",
589
+ ... patch_size=[64, 64],
590
+ ... batch_size=32,
591
+ ... num_epochs=100,
592
+ ... augmentations=[
593
+ ... # No rotation and only Y flipping
594
+ ... XYFlipModel(flip_x = False, flip_y = True)
595
+ ... ]
596
+ ... )
597
+
598
+ If you are training multiple channels they will be trained independently by default,
599
+ you simply need to specify the number of channels input (and optionally, the number
600
+ of channels output):
601
+ >>> config = create_n2n_configuration(
602
+ ... experiment_name="n2n_experiment",
603
+ ... data_type="array",
604
+ ... axes="YXC", # channels must be in the axes
605
+ ... patch_size=[64, 64],
606
+ ... batch_size=32,
607
+ ... num_epochs=100,
608
+ ... n_channels_in=3, # number of input channels
609
+ ... n_channels_out=1 # if applicable
610
+ ... )
611
+
612
+ If instead you want to train multiple channels together, you need to turn off the
613
+ `independent_channels` parameter:
614
+ >>> config = create_n2n_configuration(
615
+ ... experiment_name="n2n_experiment",
616
+ ... data_type="array",
617
+ ... axes="YXC", # channels must be in the axes
618
+ ... patch_size=[64, 64],
619
+ ... batch_size=32,
620
+ ... num_epochs=100,
621
+ ... independent_channels=False,
622
+ ... n_channels_in=3,
623
+ ... n_channels_out=1 # if applicable
624
+ ... )
302
625
  """
303
626
  if n_channels_out == -1:
304
627
  n_channels_out = n_channels_in
@@ -311,13 +634,14 @@ def create_n2n_configuration(
311
634
  patch_size=patch_size,
312
635
  batch_size=batch_size,
313
636
  num_epochs=num_epochs,
314
- use_augmentations=use_augmentations,
637
+ augmentations=augmentations,
315
638
  independent_channels=independent_channels,
316
639
  loss=loss,
317
640
  n_channels_in=n_channels_in,
318
641
  n_channels_out=n_channels_out,
319
642
  logger=logger,
320
- model_kwargs=model_kwargs,
643
+ model_params=model_params,
644
+ dataloader_params=dataloader_params,
321
645
  )
322
646
 
323
647
 
@@ -325,10 +649,10 @@ def create_n2v_configuration(
325
649
  experiment_name: str,
326
650
  data_type: Literal["array", "tiff", "custom"],
327
651
  axes: str,
328
- patch_size: List[int],
652
+ patch_size: list[int],
329
653
  batch_size: int,
330
654
  num_epochs: int,
331
- use_augmentations: bool = True,
655
+ augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
332
656
  independent_channels: bool = True,
333
657
  use_n2v2: bool = False,
334
658
  n_channels: int = 1,
@@ -337,7 +661,8 @@ def create_n2v_configuration(
337
661
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
338
662
  struct_n2v_span: int = 5,
339
663
  logger: Literal["wandb", "tensorboard", "none"] = "none",
340
- model_kwargs: Optional[dict] = None,
664
+ model_params: Optional[dict] = None,
665
+ dataloader_params: Optional[dict] = None,
341
666
  ) -> Configuration:
342
667
  """
343
668
  Create a configuration for training Noise2Void.
@@ -360,16 +685,22 @@ def create_n2v_configuration(
360
685
  By default, all channels are trained independently. To train all channels together,
361
686
  set `independent_channels` to False.
362
687
 
363
- By setting `use_augmentations` to False, the only transformations applied will be
364
- normalization and N2V manipulation.
688
+ By default, the transformations applied are a random flip along X or Y, and a random
689
+ 90 degrees rotation in the XY plane. Normalization is always applied, as well as the
690
+ N2V manipulation.
691
+
692
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
693
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
694
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
695
+ disable the transforms, simply pass an empty list.
365
696
 
366
697
  The `roi_size` parameter specifies the size of the area around each pixel that will
367
698
  be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
368
699
  pixels per patch will be manipulated.
369
700
 
370
- The parameters of the UNet can be specified in the `model_kwargs` (passed as a
701
+ The parameters of the UNet can be specified in the `model_params` (passed as a
371
702
  parameter-value dictionary). Note that `use_n2v2` and 'n_channels' override the
372
- corresponding parameters passed in `model_kwargs`.
703
+ corresponding parameters passed in `model_params`.
373
704
 
374
705
  If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
375
706
  will be applied to each manipulated pixel.
@@ -388,8 +719,10 @@ def create_n2v_configuration(
388
719
  Batch size.
389
720
  num_epochs : int
390
721
  Number of epochs.
391
- use_augmentations : bool, optional
392
- Whether to use augmentations, by default True.
722
+ augmentations : list of transforms, default=None
723
+ List of transforms to apply, either both or one of XYFlipModel and
724
+ XYRandomRotate90Model. By default, it applies both XYFlip (on X and Y)
725
+ and XYRandomRotate90 (in XY) to the images.
393
726
  independent_channels : bool, optional
394
727
  Whether to train all channels together, by default True.
395
728
  use_n2v2 : bool, optional
@@ -406,8 +739,10 @@ def create_n2v_configuration(
406
739
  Span of the structN2V mask, by default 5.
407
740
  logger : Literal["wandb", "tensorboard", "none"], optional
408
741
  Logger to use, by default "none".
409
- model_kwargs : dict, optional
410
- UNetModel parameters, by default {}.
742
+ model_params : dict, optional
743
+ UNetModel parameters, by default None.
744
+ dataloader_params : dict, optional
745
+ Parameters for the dataloader, see PyTorch notes, by default None.
411
746
 
412
747
  Returns
413
748
  -------
@@ -426,6 +761,32 @@ def create_n2v_configuration(
426
761
  ... num_epochs=100
427
762
  ... )
428
763
 
764
+ To disable transforms, simply set `augmentations` to an empty list:
765
+ >>> config = create_n2v_configuration(
766
+ ... experiment_name="n2v_experiment",
767
+ ... data_type="array",
768
+ ... axes="YX",
769
+ ... patch_size=[64, 64],
770
+ ... batch_size=32,
771
+ ... num_epochs=100,
772
+ ... augmentations=[]
773
+ ... )
774
+
775
+ A list of transforms can be passed to the `augmentations` parameter:
776
+ >>> from careamics.config.transformations import XYFlipModel
777
+ >>> config = create_n2v_configuration(
778
+ ... experiment_name="n2v_experiment",
779
+ ... data_type="array",
780
+ ... axes="YX",
781
+ ... patch_size=[64, 64],
782
+ ... batch_size=32,
783
+ ... num_epochs=100,
784
+ ... augmentations=[
785
+ ... # No rotation and only Y flipping
786
+ ... XYFlipModel(flip_x = False, flip_y = True)
787
+ ... ]
788
+ ... )
789
+
429
790
  To use N2V2, simply pass the `use_n2v2` parameter:
430
791
  >>> config = create_n2v_configuration(
431
792
  ... experiment_name="n2v2_experiment",
@@ -450,8 +811,8 @@ def create_n2v_configuration(
450
811
  ... struct_n2v_span=7
451
812
  ... )
452
813
 
453
- If you are training multiple channels independently, then you need to specify the
454
- number of channels:
814
+ If you are training multiple channels, they will be trained independently by default;
815
+ you simply need to specify the number of channels:
455
816
  >>> config = create_n2v_configuration(
456
817
  ... experiment_name="n2v_experiment",
457
818
  ... data_type="array",
@@ -474,18 +835,6 @@ def create_n2v_configuration(
474
835
  ... independent_channels=False,
475
836
  ... n_channels=3
476
837
  ... )
477
-
478
- To turn off the augmentations, except normalization and N2V manipulation, use the
479
- relevant keyword argument:
480
- >>> config = create_n2v_configuration(
481
- ... experiment_name="n2v_experiment",
482
- ... data_type="array",
483
- ... axes="YX",
484
- ... patch_size=[64, 64],
485
- ... batch_size=32,
486
- ... num_epochs=100,
487
- ... use_augmentations=False
488
- ... )
489
838
  """
490
839
  # if there are channels, we need to specify their number
491
840
  if "C" in axes and n_channels == 1:
@@ -499,77 +848,39 @@ def create_n2v_configuration(
499
848
  f"(got {n_channels} channel)."
500
849
  )
501
850
 
502
- # model
503
- if model_kwargs is None:
504
- model_kwargs = {}
505
- model_kwargs["n2v2"] = use_n2v2
506
- model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
507
- model_kwargs["in_channels"] = n_channels
508
- model_kwargs["num_classes"] = n_channels
509
- model_kwargs["independent_channels"] = independent_channels
510
-
511
- unet_model = UNetModel(
512
- architecture=SupportedArchitecture.UNET.value,
513
- **model_kwargs,
514
- )
515
-
516
- # algorithm model
517
- algorithm = AlgorithmConfig(
518
- algorithm=SupportedAlgorithm.N2V.value,
519
- loss=SupportedLoss.N2V.value,
520
- model=unet_model,
521
- )
522
-
523
851
  # augmentations
524
- if use_augmentations:
525
- transforms: List[Dict[str, Any]] = [
526
- {
527
- "name": SupportedTransform.XY_FLIP.value,
528
- },
529
- {
530
- "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
531
- },
532
- ]
533
- else:
534
- transforms = []
852
+ transform_list = _list_augmentations(augmentations)
535
853
 
536
- # n2v2 and structn2v
537
- nv2_transform = {
538
- "name": SupportedTransform.N2V_MANIPULATE.value,
539
- "strategy": (
854
+ # create the N2VManipulate transform using the supplied parameters
855
+ n2v_transform = N2VManipulateModel(
856
+ name=SupportedTransform.N2V_MANIPULATE.value,
857
+ strategy=(
540
858
  SupportedPixelManipulation.MEDIAN.value
541
859
  if use_n2v2
542
860
  else SupportedPixelManipulation.UNIFORM.value
543
861
  ),
544
- "roi_size": roi_size,
545
- "masked_pixel_percentage": masked_pixel_percentage,
546
- "struct_mask_axis": struct_n2v_axis,
547
- "struct_mask_span": struct_n2v_span,
548
- }
549
- transforms.append(nv2_transform)
862
+ roi_size=roi_size,
863
+ masked_pixel_percentage=masked_pixel_percentage,
864
+ struct_mask_axis=struct_n2v_axis,
865
+ struct_mask_span=struct_n2v_span,
866
+ )
867
+ transform_list.append(n2v_transform)
550
868
 
551
- # data model
552
- data = DataConfig(
869
+ return _create_configuration(
870
+ algorithm="n2v",
871
+ experiment_name=experiment_name,
553
872
  data_type=data_type,
554
873
  axes=axes,
555
874
  patch_size=patch_size,
556
875
  batch_size=batch_size,
557
- transforms=transforms,
558
- )
559
-
560
- # training model
561
- training = TrainingConfig(
562
876
  num_epochs=num_epochs,
563
- batch_size=batch_size,
564
- logger=None if logger == "none" else logger,
565
- )
566
-
567
- # create configuration
568
- configuration = Configuration(
569
- experiment_name=experiment_name,
570
- algorithm_config=algorithm,
571
- data_config=data,
572
- training_config=training,
877
+ augmentations=transform_list,
878
+ independent_channels=independent_channels,
879
+ loss="n2v",
880
+ use_n2v2=use_n2v2,
881
+ n_channels_in=n_channels,
882
+ n_channels_out=n_channels,
883
+ logger=logger,
884
+ model_params=model_params,
885
+ dataloader_params=dataloader_params,
573
886
  )
574
-
575
- return configuration