careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. (The original page linked to further details; that link was lost when this text was extracted.)

Files changed (66):
  1. careamics/__init__.py +8 -6
  2. careamics/careamist.py +30 -29
  3. careamics/config/__init__.py +12 -9
  4. careamics/config/algorithm_model.py +5 -5
  5. careamics/config/architectures/unet_model.py +1 -0
  6. careamics/config/callback_model.py +1 -0
  7. careamics/config/configuration_example.py +87 -0
  8. careamics/config/configuration_factory.py +285 -78
  9. careamics/config/configuration_model.py +22 -23
  10. careamics/config/data_model.py +62 -160
  11. careamics/config/inference_model.py +20 -21
  12. careamics/config/references/algorithm_descriptions.py +1 -0
  13. careamics/config/references/references.py +1 -0
  14. careamics/config/support/supported_extraction_strategies.py +1 -0
  15. careamics/config/support/supported_optimizers.py +3 -3
  16. careamics/config/training_model.py +2 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +2 -1
  18. careamics/config/transformations/nd_flip_model.py +7 -12
  19. careamics/config/transformations/normalize_model.py +2 -1
  20. careamics/config/transformations/transform_model.py +1 -0
  21. careamics/config/transformations/xy_random_rotate90_model.py +7 -9
  22. careamics/config/validators/validator_utils.py +1 -0
  23. careamics/conftest.py +1 -0
  24. careamics/dataset/dataset_utils/__init__.py +0 -1
  25. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  26. careamics/dataset/in_memory_dataset.py +17 -48
  27. careamics/dataset/iterable_dataset.py +16 -71
  28. careamics/dataset/patching/__init__.py +0 -7
  29. careamics/dataset/patching/patching.py +1 -0
  30. careamics/dataset/patching/sequential_patching.py +6 -6
  31. careamics/dataset/patching/tiled_patching.py +10 -6
  32. careamics/lightning_datamodule.py +123 -49
  33. careamics/lightning_module.py +7 -7
  34. careamics/lightning_prediction_datamodule.py +59 -48
  35. careamics/losses/__init__.py +0 -1
  36. careamics/losses/loss_factory.py +1 -0
  37. careamics/model_io/__init__.py +0 -1
  38. careamics/model_io/bioimage/_readme_factory.py +2 -1
  39. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  40. careamics/model_io/bioimage/model_description.py +4 -3
  41. careamics/model_io/bmz_io.py +8 -7
  42. careamics/model_io/model_io_utils.py +4 -4
  43. careamics/models/layers.py +1 -0
  44. careamics/models/model_factory.py +1 -0
  45. careamics/models/unet.py +91 -17
  46. careamics/prediction/stitch_prediction.py +1 -0
  47. careamics/transforms/__init__.py +2 -23
  48. careamics/transforms/compose.py +98 -0
  49. careamics/transforms/n2v_manipulate.py +18 -23
  50. careamics/transforms/nd_flip.py +38 -64
  51. careamics/transforms/normalize.py +45 -34
  52. careamics/transforms/pixel_manipulation.py +2 -2
  53. careamics/transforms/transform.py +33 -0
  54. careamics/transforms/tta.py +2 -2
  55. careamics/transforms/xy_random_rotate90.py +41 -68
  56. careamics/utils/__init__.py +0 -1
  57. careamics/utils/context.py +1 -0
  58. careamics/utils/logging.py +1 -0
  59. careamics/utils/metrics.py +1 -0
  60. careamics/utils/torch_utils.py +1 -0
  61. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  62. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  63. careamics/dataset/patching/patch_transform.py +0 -44
  64. careamics-0.1.0rc3.dist-info/RECORD +0 -109
  65. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  66. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
@@ -2,13 +2,11 @@
2
2
 
3
3
  from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4
4
 
5
- from albumentations import Compose
6
-
7
- from .algorithm_model import AlgorithmModel
5
+ from .algorithm_model import AlgorithmConfig
8
6
  from .architectures import UNetModel
9
7
  from .configuration_model import Configuration
10
- from .data_model import DataModel
11
- from .inference_model import InferenceModel
8
+ from .data_model import DataConfig
9
+ from .inference_model import InferenceConfig
12
10
  from .support import (
13
11
  SupportedAlgorithm,
14
12
  SupportedArchitecture,
@@ -16,10 +14,11 @@ from .support import (
16
14
  SupportedPixelManipulation,
17
15
  SupportedTransform,
18
16
  )
19
- from .training_model import TrainingModel
17
+ from .training_model import TrainingConfig
20
18
 
21
19
 
22
- def create_n2n_configuration(
20
+ def _create_supervised_configuration(
21
+ algorithm: Literal["care", "n2n"],
23
22
  experiment_name: str,
24
23
  data_type: Literal["array", "tiff", "custom"],
25
24
  axes: str,
@@ -27,28 +26,20 @@ def create_n2n_configuration(
27
26
  batch_size: int,
28
27
  num_epochs: int,
29
28
  use_augmentations: bool = True,
30
- use_n2v2: bool = False,
31
- n_channels: int = 1,
29
+ independent_channels: bool = False,
30
+ loss: Literal["mae", "mse"] = "mae",
31
+ n_channels_in: int = 1,
32
+ n_channels_out: int = 1,
32
33
  logger: Literal["wandb", "tensorboard", "none"] = "none",
33
34
  model_kwargs: Optional[dict] = None,
34
35
  ) -> Configuration:
35
36
  """
36
- Create a configuration for training N2V.
37
-
38
- If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
39
- 2.
40
-
41
- By setting `use_augmentations` to False, the only transformation applied will be
42
- normalization and N2V manipulation.
43
-
44
- The parameter `use_n2v2` overrides the corresponding `n2v2` that can be passed
45
- in `model_kwargs`.
46
-
47
- If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
48
- will be applied to each manipulated pixel.
37
+ Create a configuration for training CARE or Noise2Noise.
49
38
 
50
39
  Parameters
51
40
  ----------
41
+ algorithm : Literal["care", "n2n"]
42
+ Algorithm to use.
52
43
  experiment_name : str
53
44
  Name of the experiment.
54
45
  data_type : Literal["array", "tiff", "custom"]
@@ -63,18 +54,14 @@ def create_n2n_configuration(
63
54
  Number of epochs.
64
55
  use_augmentations : bool, optional
65
56
  Whether to use augmentations, by default True.
66
- use_n2v2 : bool, optional
67
- Whether to use N2V2, by default False.
68
- n_channels : int, optional
69
- Number of channels (in and out), by default 1.
70
- roi_size : int, optional
71
- N2V pixel manipulation area, by default 11.
72
- masked_pixel_percentage : float, optional
73
- Percentage of pixels masked in each patch, by default 0.2.
74
- struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
75
- Axis along which to apply structN2V mask, by default "none".
76
- struct_n2v_span : int, optional
77
- Span of the structN2V mask, by default 5.
57
+ independent_channels : bool, optional
58
+ Whether to train all channels independently, by default False.
59
+ loss : Literal["mae", "mse"], optional
60
+ Loss function to use, by default "mae".
61
+ n_channels_in : int, optional
62
+ Number of channels in, by default 1.
63
+ n_channels_out : int, optional
64
+ Number of channels out, by default 1.
78
65
  logger : Literal["wandb", "tensorboard", "none"], optional
79
66
  Logger to use, by default "none".
80
67
  model_kwargs : dict, optional
@@ -83,15 +70,27 @@ def create_n2n_configuration(
83
70
  Returns
84
71
  -------
85
72
  Configuration
86
- Configuration for training N2V.
73
+ Configuration for training CARE or Noise2Noise.
87
74
  """
75
+ # if there are channels, we need to specify their number
76
+ if "C" in axes and n_channels_in == 1:
77
+ raise ValueError(
78
+ f"Number of channels in must be specified when using channels "
79
+ f"(got {n_channels_in} channel)."
80
+ )
81
+ elif "C" not in axes and n_channels_in > 1:
82
+ raise ValueError(
83
+ f"C is not present in the axes, but number of channels is specified "
84
+ f"(got {n_channels_in} channels)."
85
+ )
86
+
88
87
  # model
89
88
  if model_kwargs is None:
90
89
  model_kwargs = {}
91
- model_kwargs["n2v2"] = use_n2v2
92
90
  model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
93
- model_kwargs["in_channels"] = n_channels
94
- model_kwargs["num_classes"] = n_channels
91
+ model_kwargs["in_channels"] = n_channels_in
92
+ model_kwargs["num_classes"] = n_channels_out
93
+ model_kwargs["independent_channels"] = independent_channels
95
94
 
96
95
  unet_model = UNetModel(
97
96
  architecture=SupportedArchitecture.UNET.value,
@@ -99,9 +98,9 @@ def create_n2n_configuration(
99
98
  )
100
99
 
101
100
  # algorithm model
102
- algorithm = AlgorithmModel(
103
- algorithm=SupportedAlgorithm.N2V.value,
104
- loss=SupportedLoss.N2V.value,
101
+ algorithm = AlgorithmConfig(
102
+ algorithm=algorithm,
103
+ loss=loss,
105
104
  model=unet_model,
106
105
  )
107
106
 
@@ -126,7 +125,7 @@ def create_n2n_configuration(
126
125
  ]
127
126
 
128
127
  # data model
129
- data = DataModel(
128
+ data = DataConfig(
130
129
  data_type=data_type,
131
130
  axes=axes,
132
131
  patch_size=patch_size,
@@ -135,7 +134,7 @@ def create_n2n_configuration(
135
134
  )
136
135
 
137
136
  # training model
138
- training = TrainingModel(
137
+ training = TrainingConfig(
139
138
  num_epochs=num_epochs,
140
139
  batch_size=batch_size,
141
140
  logger=None if logger == "none" else logger,
@@ -152,6 +151,175 @@ def create_n2n_configuration(
152
151
  return configuration
153
152
 
154
153
 
154
+ def create_care_configuration(
155
+ experiment_name: str,
156
+ data_type: Literal["array", "tiff", "custom"],
157
+ axes: str,
158
+ patch_size: List[int],
159
+ batch_size: int,
160
+ num_epochs: int,
161
+ use_augmentations: bool = True,
162
+ independent_channels: bool = False,
163
+ loss: Literal["mae", "mse"] = "mae",
164
+ n_channels_in: int = 1,
165
+ n_channels_out: int = -1,
166
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
167
+ model_kwargs: Optional[dict] = None,
168
+ ) -> Configuration:
169
+ """
170
+ Create a configuration for training CARE.
171
+
172
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
173
+ 2.
174
+
175
+ If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
176
+ channels. Likewise, if you set the number of channels, then "C" must be present in
177
+ `axes`.
178
+
179
+ To set the number of output channels, use the `n_channels_out` parameter. If it is
180
+ not specified, it will be assumed to be equal to `n_channels_in`.
181
+
182
+ By default, all channels are trained together. To train all channels independently,
183
+ set `independent_channels` to True.
184
+
185
+ By setting `use_augmentations` to False, the only transformation applied will be
186
+ normalization.
187
+
188
+ Parameters
189
+ ----------
190
+ experiment_name : str
191
+ Name of the experiment.
192
+ data_type : Literal["array", "tiff", "custom"]
193
+ Type of the data.
194
+ axes : str
195
+ Axes of the data (e.g. SYX).
196
+ patch_size : List[int]
197
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
198
+ batch_size : int
199
+ Batch size.
200
+ num_epochs : int
201
+ Number of epochs.
202
+ use_augmentations : bool, optional
203
+ Whether to use augmentations, by default True.
204
+ independent_channels : bool, optional
205
+ Whether to train all channels independently, by default False.
206
+ loss : Literal["mae", "mse"], optional
207
+ Loss function to use, by default "mae".
208
+ n_channels_in : int, optional
209
+ Number of channels in, by default 1.
210
+ n_channels_out : int, optional
211
+ Number of channels out, by default -1.
212
+ logger : Literal["wandb", "tensorboard", "none"], optional
213
+ Logger to use, by default "none".
214
+ model_kwargs : dict, optional
215
+ UNetModel parameters, by default {}.
216
+
217
+ Returns
218
+ -------
219
+ Configuration
220
+ Configuration for training CARE.
221
+ """
222
+ if n_channels_out == -1:
223
+ n_channels_out = n_channels_in
224
+
225
+ return _create_supervised_configuration(
226
+ algorithm="care",
227
+ experiment_name=experiment_name,
228
+ data_type=data_type,
229
+ axes=axes,
230
+ patch_size=patch_size,
231
+ batch_size=batch_size,
232
+ num_epochs=num_epochs,
233
+ use_augmentations=use_augmentations,
234
+ independent_channels=independent_channels,
235
+ loss=loss,
236
+ n_channels_in=n_channels_in,
237
+ n_channels_out=n_channels_out,
238
+ logger=logger,
239
+ model_kwargs=model_kwargs,
240
+ )
241
+
242
+
243
+ def create_n2n_configuration(
244
+ experiment_name: str,
245
+ data_type: Literal["array", "tiff", "custom"],
246
+ axes: str,
247
+ patch_size: List[int],
248
+ batch_size: int,
249
+ num_epochs: int,
250
+ use_augmentations: bool = True,
251
+ independent_channels: bool = False,
252
+ loss: Literal["mae", "mse"] = "mae",
253
+ n_channels: int = 1,
254
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
255
+ model_kwargs: Optional[dict] = None,
256
+ ) -> Configuration:
257
+ """
258
+ Create a configuration for training Noise2Noise.
259
+
260
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
261
+ 2.
262
+
263
+ If "C" is present in `axes`, then you need to set `n_channels` to the number of
264
+ channels. Likewise, if you set the number of channels, then "C" must be present in
265
+ `axes`.
266
+
267
+ By default, all channels are trained together. To train all channels independently,
268
+ set `independent_channels` to True.
269
+
270
+ By setting `use_augmentations` to False, the only transformation applied will be
271
+ normalization.
272
+
273
+ Parameters
274
+ ----------
275
+ experiment_name : str
276
+ Name of the experiment.
277
+ data_type : Literal["array", "tiff", "custom"]
278
+ Type of the data.
279
+ axes : str
280
+ Axes of the data (e.g. SYX).
281
+ patch_size : List[int]
282
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
283
+ batch_size : int
284
+ Batch size.
285
+ num_epochs : int
286
+ Number of epochs.
287
+ use_augmentations : bool, optional
288
+ Whether to use augmentations, by default True.
289
+ independent_channels : bool, optional
290
+ Whether to train all channels independently, by default False.
291
+ loss : Literal["mae", "mse"], optional
292
+ Loss function to use, by default "mae".
293
+ n_channels : int, optional
294
+ Number of channels (in and out), by default 1.
295
+ logger : Literal["wandb", "tensorboard", "none"], optional
296
+ Logger to use, by default "none".
297
+ model_kwargs : dict, optional
298
+ UNetModel parameters, by default {}.
299
+
300
+ Returns
301
+ -------
302
+ Configuration
303
+ Configuration for training Noise2Noise.
304
+ """
305
+ return _create_supervised_configuration(
306
+ algorithm="n2n",
307
+ experiment_name=experiment_name,
308
+ data_type=data_type,
309
+ axes=axes,
310
+ patch_size=patch_size,
311
+ batch_size=batch_size,
312
+ num_epochs=num_epochs,
313
+ use_augmentations=use_augmentations,
314
+ independent_channels=independent_channels,
315
+ loss=loss,
316
+ n_channels_in=n_channels,
317
+ n_channels_out=n_channels,
318
+ logger=logger,
319
+ model_kwargs=model_kwargs,
320
+ )
321
+
322
+
155
323
  def create_n2v_configuration(
156
324
  experiment_name: str,
157
325
  data_type: Literal["array", "tiff", "custom"],
@@ -160,8 +328,9 @@ def create_n2v_configuration(
160
328
  batch_size: int,
161
329
  num_epochs: int,
162
330
  use_augmentations: bool = True,
331
+ independent_channels: bool = True,
163
332
  use_n2v2: bool = False,
164
- n_channels: int = -1,
333
+ n_channels: int = 1,
165
334
  roi_size: int = 11,
166
335
  masked_pixel_percentage: float = 0.2,
167
336
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -170,7 +339,7 @@ def create_n2v_configuration(
170
339
  model_kwargs: Optional[dict] = None,
171
340
  ) -> Configuration:
172
341
  """
173
- Create a configuration for training N2V.
342
+ Create a configuration for training Noise2Void.
174
343
 
175
344
  N2V uses a UNet model to denoise images in a self-supervised manner. To use its
176
345
  variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
@@ -181,11 +350,14 @@ def create_n2v_configuration(
181
350
  or horizontal correlations are present in the noise; it applies an additional mask
182
351
  to the manipulated pixel neighbors.
183
352
 
353
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
354
+ 2.
355
+
184
356
  If "C" is present in `axes`, then you need to set `n_channels` to the number of
185
357
  channels.
186
358
 
187
- If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
188
- 2.
359
+ By default, all channels are trained independently. To train all channels together,
360
+ set `independent_channels` to False.
189
361
 
190
362
  By setting `use_augmentations` to False, the only transformations applied will be
191
363
  normalization and N2V manipulation.
@@ -217,10 +389,12 @@ def create_n2v_configuration(
217
389
  Number of epochs.
218
390
  use_augmentations : bool, optional
219
391
  Whether to use augmentations, by default True.
392
+ independent_channels : bool, optional
393
+ Whether to train all channels together, by default True.
220
394
  use_n2v2 : bool, optional
221
395
  Whether to use N2V2, by default False.
222
396
  n_channels : int, optional
223
- Number of channels (in and out), by default -1.
397
+ Number of channels (in and out), by default 1.
224
398
  roi_size : int, optional
225
399
  N2V pixel manipulation area, by default 11.
226
400
  masked_pixel_percentage : float, optional
@@ -275,8 +449,20 @@ def create_n2v_configuration(
275
449
  ... struct_n2v_span=7
276
450
  ... )
277
451
 
278
- If you are training multiple channels together, then you need to specify the number
279
- of channels:
452
+ If you are training multiple channels independently, then you need to specify the
453
+ number of channels:
454
+ >>> config = create_n2v_configuration(
455
+ ... experiment_name="n2v_experiment",
456
+ ... data_type="array",
457
+ ... axes="YXC",
458
+ ... patch_size=[64, 64],
459
+ ... batch_size=32,
460
+ ... num_epochs=100,
461
+ ... n_channels=3
462
+ ... )
463
+
464
+ If instead you want to train multiple channels together, you need to turn off the
465
+ `independent_channels` parameter:
280
466
  >>> config = create_n2v_configuration(
281
467
  ... experiment_name="n2v_experiment",
282
468
  ... data_type="array",
@@ -284,6 +470,7 @@ def create_n2v_configuration(
284
470
  ... patch_size=[64, 64],
285
471
  ... batch_size=32,
286
472
  ... num_epochs=100,
473
+ ... independent_channels=False,
287
474
  ... n_channels=3
288
475
  ... )
289
476
 
@@ -300,18 +487,16 @@ def create_n2v_configuration(
300
487
  ... )
301
488
  """
302
489
  # if there are channels, we need to specify their number
303
- if "C" in axes and n_channels == -1:
490
+ if "C" in axes and n_channels == 1:
304
491
  raise ValueError(
305
492
  f"Number of channels must be specified when using channels "
306
493
  f"(got {n_channels} channel)."
307
494
  )
308
- elif "C" not in axes and n_channels != -1:
495
+ elif "C" not in axes and n_channels > 1:
309
496
  raise ValueError(
310
497
  f"C is not present in the axes, but number of channels is specified "
311
498
  f"(got {n_channels} channel)."
312
499
  )
313
- elif n_channels == -1:
314
- n_channels = 1
315
500
 
316
501
  # model
317
502
  if model_kwargs is None:
@@ -320,6 +505,7 @@ def create_n2v_configuration(
320
505
  model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
321
506
  model_kwargs["in_channels"] = n_channels
322
507
  model_kwargs["num_classes"] = n_channels
508
+ model_kwargs["independent_channels"] = independent_channels
323
509
 
324
510
  unet_model = UNetModel(
325
511
  architecture=SupportedArchitecture.UNET.value,
@@ -327,7 +513,7 @@ def create_n2v_configuration(
327
513
  )
328
514
 
329
515
  # algorithm model
330
- algorithm = AlgorithmModel(
516
+ algorithm = AlgorithmConfig(
331
517
  algorithm=SupportedAlgorithm.N2V.value,
332
518
  loss=SupportedLoss.N2V.value,
333
519
  model=unet_model,
@@ -356,9 +542,11 @@ def create_n2v_configuration(
356
542
  # n2v2 and structn2v
357
543
  nv2_transform = {
358
544
  "name": SupportedTransform.N2V_MANIPULATE.value,
359
- "strategy": SupportedPixelManipulation.MEDIAN.value
360
- if use_n2v2
361
- else SupportedPixelManipulation.UNIFORM.value,
545
+ "strategy": (
546
+ SupportedPixelManipulation.MEDIAN.value
547
+ if use_n2v2
548
+ else SupportedPixelManipulation.UNIFORM.value
549
+ ),
362
550
  "roi_size": roi_size,
363
551
  "masked_pixel_percentage": masked_pixel_percentage,
364
552
  "struct_mask_axis": struct_n2v_axis,
@@ -367,7 +555,7 @@ def create_n2v_configuration(
367
555
  transforms.append(nv2_transform)
368
556
 
369
557
  # data model
370
- data = DataModel(
558
+ data = DataConfig(
371
559
  data_type=data_type,
372
560
  axes=axes,
373
561
  patch_size=patch_size,
@@ -376,7 +564,7 @@ def create_n2v_configuration(
376
564
  )
377
565
 
378
566
  # training model
379
- training = TrainingModel(
567
+ training = TrainingConfig(
380
568
  num_epochs=num_epochs,
381
569
  batch_size=batch_size,
382
570
  logger=None if logger == "none" else logger,
@@ -393,17 +581,16 @@ def create_n2v_configuration(
393
581
  return configuration
394
582
 
395
583
 
396
- # TODO add tests
397
584
  def create_inference_configuration(
398
- training_configuration: Configuration,
585
+ configuration: Configuration,
399
586
  tile_size: Optional[Tuple[int, ...]] = None,
400
587
  tile_overlap: Optional[Tuple[int, ...]] = None,
401
588
  data_type: Optional[Literal["array", "tiff", "custom"]] = None,
402
589
  axes: Optional[str] = None,
403
- transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
590
+ transforms: Optional[Union[List[Dict[str, Any]]]] = None,
404
591
  tta_transforms: bool = True,
405
592
  batch_size: Optional[int] = 1,
406
- ) -> InferenceModel:
593
+ ) -> InferenceConfig:
407
594
  """
408
595
  Create a configuration for inference with N2V.
409
596
 
@@ -412,8 +599,8 @@ def create_inference_configuration(
412
599
 
413
600
  Parameters
414
601
  ----------
415
- training_configuration : Configuration
416
- Configuration used for training.
602
+ configuration : Configuration
603
+ Global configuration.
417
604
  tile_size : Tuple[int, ...], optional
418
605
  Size of the tiles.
419
606
  tile_overlap : Tuple[int, ...], optional
@@ -422,7 +609,7 @@ def create_inference_configuration(
422
609
  Type of the data, by default "tiff".
423
610
  axes : str, optional
424
611
  Axes of the data, by default "YX".
425
- transforms : List[Dict[str, Any]] or Compose, optional
612
+ transforms : List[Dict[str, Any]], optional
426
613
  Transformations to apply to the data, by default None.
427
614
  tta_transforms : bool, optional
428
615
  Whether to apply test-time augmentations, by default True.
@@ -432,14 +619,12 @@ def create_inference_configuration(
432
619
  Returns
433
620
  -------
434
621
  InferenceConfiguration
435
- Configuration for inference with N2V.
622
+ Configuration used to configure CAREamicsPredictData.
436
623
  """
437
- if (
438
- training_configuration.data_config.mean is None
439
- or training_configuration.data_config.std is None
440
- ):
441
- raise ValueError("Mean and std must be provided in the training configuration.")
624
+ if configuration.data_config.mean is None or configuration.data_config.std is None:
625
+ raise ValueError("Mean and std must be provided in the configuration.")
442
626
 
627
+ # minimum transform
443
628
  if transforms is None:
444
629
  transforms = [
445
630
  {
@@ -447,13 +632,35 @@ def create_inference_configuration(
447
632
  },
448
633
  ]
449
634
 
450
- return InferenceModel(
451
- data_type=data_type or training_configuration.data_config.data_type,
635
+ # tile size for UNets
636
+ if tile_size is not None:
637
+ model = configuration.algorithm_config.model
638
+
639
+ if model.architecture == SupportedArchitecture.UNET.value:
640
+ # tile size must be equal to k*2^n, where n is the number of pooling layers
641
+ # (equal to the depth) and k is an integer
642
+ depth = model.depth
643
+ tile_increment = 2**depth
644
+
645
+ for i, t in enumerate(tile_size):
646
+ if t % tile_increment != 0:
647
+ raise ValueError(
648
+ f"Tile size must be divisible by {tile_increment} along all "
649
+ f"axes (got {t} for axis {i}). If your image size is smaller "
650
+ f"along one axis (e.g. Z), consider padding the image."
651
+ )
652
+
653
+ # tile overlaps must be specified
654
+ if tile_overlap is None:
655
+ raise ValueError("Tile overlap must be specified.")
656
+
657
+ return InferenceConfig(
658
+ data_type=data_type or configuration.data_config.data_type,
452
659
  tile_size=tile_size,
453
660
  tile_overlap=tile_overlap,
454
- axes=axes or training_configuration.data_config.axes,
455
- mean=training_configuration.data_config.mean,
456
- std=training_configuration.data_config.std,
661
+ axes=axes or configuration.data_config.axes,
662
+ mean=configuration.data_config.mean,
663
+ std=configuration.data_config.std,
457
664
  transforms=transforms,
458
665
  tta_transforms=tta_transforms,
459
666
  batch_size=batch_size,
@@ -1,4 +1,5 @@
1
1
  """Pydantic CAREamics configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import re
@@ -11,8 +12,8 @@ from bioimageio.spec.generic.v0_3 import CiteEntry
11
12
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
12
13
  from typing_extensions import Self
13
14
 
14
- from .algorithm_model import AlgorithmModel
15
- from .data_model import DataModel
15
+ from .algorithm_model import AlgorithmConfig
16
+ from .data_model import DataConfig
16
17
  from .references import (
17
18
  CARE,
18
19
  CUSTOM,
@@ -34,7 +35,7 @@ from .references import (
34
35
  StructN2VRef,
35
36
  )
36
37
  from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
37
- from .training_model import TrainingModel
38
+ from .training_model import TrainingConfig
38
39
  from .transformations.n2v_manipulate_model import (
39
40
  N2VManipulateModel,
40
41
  )
@@ -156,9 +157,10 @@ class Configuration(BaseModel):
156
157
  )
157
158
 
158
159
  # Sub-configurations
159
- algorithm_config: AlgorithmModel
160
- data_config: DataModel
161
- training_config: TrainingModel
160
+ algorithm_config: AlgorithmConfig
161
+
162
+ data_config: DataConfig
163
+ training_config: TrainingConfig
162
164
 
163
165
  @field_validator("experiment_name")
164
166
  @classmethod
@@ -237,25 +239,22 @@ class Configuration(BaseModel):
237
239
  Validated configuration.
238
240
  """
239
241
  if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
240
- # if we have a list of transform (as opposed to Compose)
241
- if self.data_config.has_transform_list():
242
- # missing N2V_MANIPULATE
243
- if not self.data_config.has_n2v_manipulate():
244
- self.data_config.transforms.append(
245
- N2VManipulateModel(
246
- name=SupportedTransform.N2V_MANIPULATE.value,
247
- )
242
+ # missing N2V_MANIPULATE
243
+ if not self.data_config.has_n2v_manipulate():
244
+ self.data_config.transforms.append(
245
+ N2VManipulateModel(
246
+ name=SupportedTransform.N2V_MANIPULATE.value,
248
247
  )
248
+ )
249
249
 
250
- median = SupportedPixelManipulation.MEDIAN.value
251
- uniform = SupportedPixelManipulation.UNIFORM.value
252
- strategy = median if self.algorithm_config.model.n2v2 else uniform
253
- self.data_config.set_N2V2_strategy(strategy)
250
+ median = SupportedPixelManipulation.MEDIAN.value
251
+ uniform = SupportedPixelManipulation.UNIFORM.value
252
+ strategy = median if self.algorithm_config.model.n2v2 else uniform
253
+ self.data_config.set_N2V2_strategy(strategy)
254
254
  else:
255
- # if we have a list of transform, remove N2V manipulate if present
256
- if self.data_config.has_transform_list():
257
- if self.data_config.has_n2v_manipulate():
258
- self.data_config.remove_n2v_manipulate()
255
+ # remove N2V manipulate if present
256
+ if self.data_config.has_n2v_manipulate():
257
+ self.data_config.remove_n2v_manipulate()
259
258
 
260
259
  return self
261
260
 
@@ -591,6 +590,6 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
591
590
  # save configuration as dictionary to yaml
592
591
  with open(config_path, "w") as f:
593
592
  # dump configuration
594
- yaml.dump(config.model_dump(), f, default_flow_style=False)
593
+ yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
595
594
 
596
595
  return config_path