careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/config/__init__.py

@@ -14,9 +14,7 @@ __all__ = [
     "create_care_configuration",
     "register_model",
     "CustomModel",
-    "create_inference_configuration",
     "clear_custom_models",
-    "ConfigurationInformation",
 ]
 
 from .algorithm_model import AlgorithmConfig
@@ -24,7 +22,6 @@ from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
     create_care_configuration,
-    create_inference_configuration,
     create_n2n_configuration,
     create_n2v_configuration,
 )
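For orientation, a short usage sketch (not part of the diff): after this change, `create_inference_configuration` and `ConfigurationInformation` are no longer exported, and user code imports only the remaining factory functions.

# Sketch of the public imports that remain available after this release.
from careamics.config import (
    create_care_configuration,
    create_n2n_configuration,
    create_n2v_configuration,
)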
careamics/config/algorithm_model.py

@@ -93,12 +93,20 @@ class AlgorithmConfig(BaseModel):
 
     # Mandatory fields
     algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
+    """Name of the algorithm, as defined in SupportedAlgorithm."""
+
     loss: Literal["n2v", "mae", "mse"]
+    """Loss function to use, as defined in SupportedLoss."""
+
     model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
+    """Model architecture to use, defined in SupportedArchitecture."""
 
     # Optional fields
     optimizer: OptimizerModel = OptimizerModel()
+    """Optimizer to use, defined in SupportedOptimizer."""
+
     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+    """Learning rate scheduler to use, defined in SupportedScheduler."""
 
     @model_validator(mode="after")
     def algorithm_cross_validation(self: Self) -> Self:
@@ -134,21 +142,6 @@ class AlgorithmConfig(BaseModel):
                     "sure that `in_channels` and `num_classes` are the same."
                 )
 
-        # N2N
-        if self.algorithm == "n2n":
-            # n2n is only compatible with the UNet model
-            if not isinstance(self.model, UNetModel):
-                raise ValueError(
-                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
-                )
-
-            # n2n requires the number of input and output channels to be the same
-            if self.model.in_channels != self.model.num_classes:
-                raise ValueError(
-                    "N2N requires the same number of input and output channels. Make "
-                    "sure that `in_channels` and `num_classes` are the same."
-                )
-
         if self.algorithm == "care" or self.algorithm == "n2n":
             if self.loss == "n2v":
                 raise ValueError("Supervised algorithms do not support loss `n2v`.")
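As an illustration of the fields documented above (a sketch, not output of the package; only the field names shown in this hunk are used, everything else relies on the model defaults):

# Sketch: the discriminated union resolves the dict to UNetModel via "architecture".
from careamics.config.algorithm_model import AlgorithmConfig

algorithm = AlgorithmConfig(
    algorithm="n2v",
    loss="n2v",
    model={"architecture": "UNet"},
)
print(algorithm.optimizer)     # default OptimizerModel
print(algorithm.lr_scheduler)  # default LrSchedulerModel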
careamics/config/architectures/architecture_model.py

@@ -13,6 +13,7 @@ class ArchitectureModel(BaseModel):
     """
 
     architecture: str
+    """Name of the architecture."""
 
     def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
         """
careamics/config/architectures/custom_model.py

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Any, Dict, Literal
+from typing import Any, Literal
 
 from pydantic import ConfigDict, field_validator, model_validator
 from torch.nn import Module
@@ -72,9 +72,11 @@ class CustomModel(ArchitectureModel):
 
     # discriminator used for choosing the pydantic model in Model
     architecture: Literal["Custom"]
+    """Name of the architecture."""
 
     # name of the custom model
     name: str
+    """Name of the custom model."""
 
     @field_validator("name")
     @classmethod
@@ -136,7 +138,7 @@
         """
         return pformat(self.model_dump())
 
-    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
         """Dump the model configuration.
 
         Parameters
@@ -146,7 +148,7 @@
 
         Returns
         -------
-        Dict[str, Any]
+        dict[str, Any]
             Model configuration.
         """
         model_dict = super().model_dump()
careamics/config/architectures/unet_model.py

@@ -29,19 +29,38 @@ class UNetModel(ArchitectureModel):
 
     # discriminator used for choosing the pydantic model in Model
     architecture: Literal["UNet"]
+    """Name of the architecture."""
 
     # parameters
     # validate_defaults allow ignoring default values in the dump if they were not set
     conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
+    """Dimensions (2D or 3D) of the convolutional layers."""
+
     num_classes: int = Field(default=1, ge=1, validate_default=True)
+    """Number of classes or channels in the model output."""
+
     in_channels: int = Field(default=1, ge=1, validate_default=True)
+    """Number of channels in the input to the model."""
+
     depth: int = Field(default=2, ge=1, le=10, validate_default=True)
+    """Number of levels in the UNet."""
+
     num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
+    """Number of convolutional filters in the first layer of the UNet."""
+
     final_activation: Literal[
         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
     ] = Field(default="None", validate_default=True)
+    """Final activation function."""
+
     n2v2: bool = Field(default=False, validate_default=True)
+    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
+    skip connections.
+    """
+
     independent_channels: bool = Field(default=True, validate_default=True)
+    """Whether information is processed independently in each channel, used to train
+    channels independently."""
 
     @field_validator("num_channels_init")
     @classmethod
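To make the new field documentation concrete, here is a small sketch (field names, defaults and bounds are those shown above; the values are arbitrary):

# Sketch: a 2D N2V2-style UNet configuration.
from careamics.config.architectures import UNetModel

unet = UNetModel(
    architecture="UNet",
    conv_dims=2,            # 2D convolutions
    in_channels=1,
    num_classes=1,
    depth=3,                # number of levels, within the ge=1 / le=10 bounds
    num_channels_init=32,   # filters in the first layer, ge=8 / le=1024
    n2v2=True,              # blur pool layers and fewer skip connections
    independent_channels=True,
)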
careamics/config/architectures/vae_model.py

@@ -17,6 +17,7 @@ class VAEModel(ArchitectureModel):
     )
 
     architecture: Literal["VAE"]
+    """Name of the architecture."""
 
     def set_3D(self, is_3D: bool) -> None:
         """
careamics/config/callback_model.py

@@ -13,69 +13,111 @@ from pydantic import (
 
 
 class CheckpointModel(BaseModel):
-    """Checkpoint saving callback Pydantic model."""
+    """Checkpoint saving callback Pydantic model.
+
+    The parameters corresponds to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+    """
 
     model_config = ConfigDict(
         validate_assignment=True,
     )
 
     monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
     verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
     save_weights_only: bool = Field(default=False, validate_default=True)
+    """When `True`, only the model's weights will be saved (model.save_weights)."""
+
+    save_last: Optional[Literal[True, False, "link"]] = Field(
+        default=True, validate_default=True
+    )
+    """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
+
+    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """If `save_top_k == kz, the best k models according to the quantity monitored
+    will be saved. If `save_top_k == 0`, no models are saved. if `save_top_k == -1`,
+    all models are saved."""
+
     mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
+    save file is made based on either the maximization or the minimization of the
+    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
+    be 'min', etc.
+    """
+
     auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+    """When `True`, the checkpoints filenames will contain the metric name."""
+
     every_n_train_steps: Optional[int] = Field(
         default=None, ge=1, le=10, validate_default=True
     )
+    """Number of training steps between checkpoints."""
+
     train_time_interval: Optional[timedelta] = Field(
         default=None, validate_default=True
     )
+    """Checkpoints are monitored at the specified time interval."""
+
     every_n_epochs: Optional[int] = Field(
         default=None, ge=1, le=10, validate_default=True
     )
-    save_last: Optional[Literal[True, False, "link"]] = Field(
-        default=True, validate_default=True
-    )
-    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """Number of epochs between checkpoints."""
 
 
 class EarlyStoppingModel(BaseModel):
-    """Early stopping callback Pydantic model."""
+    """Early stopping callback Pydantic model.
+
+    The parameters corresponds to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
+    """
 
     model_config = ConfigDict(
         validate_assignment=True,
     )
 
     monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
+    absolute change of less than or equal to min_delta, will count as no improvement."""
+
     patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """Number of checks with no improvement after which training will be stopped."""
+
+    verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
     mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
-    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    """One of {min, max, auto}."""
+
     check_finite: bool = Field(default=True, validate_default=True)
-    stop_on_nan: bool = Field(default=True, validate_default=True)
-    verbose: bool = Field(default=False, validate_default=True)
-    restore_best_weights: bool = Field(default=True, validate_default=True)
-    auto_lr_find: bool = Field(default=False, validate_default=True)
-    auto_lr_find_patience: int = Field(default=3, ge=1, le=10, validate_default=True)
-    auto_lr_find_mode: Literal["min", "max", "auto"] = Field(
-        default="min", validate_default=True
-    )
-    auto_lr_find_direction: Literal["forward", "backward"] = Field(
-        default="backward", validate_default=True
-    )
-    auto_lr_find_max_lr: float = Field(
-        default=10.0, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_min_lr: float = Field(
-        default=1e-8, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_num_training: int = Field(
-        default=100, ge=1, le=1e6, validate_default=True
-    )
-    auto_lr_find_divergence_threshold: float = Field(
-        default=5.0, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_accumulate_grad_batches: int = Field(
-        default=1, ge=1, le=1e6, validate_default=True
+    """When `True`, stops training when the monitored quantity becomes `NaN` or
+    `inf`."""
+
+    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training immediately once the monitored quantity reaches this threshold."""
+
+    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training as soon as the monitored quantity becomes worse than this
+    threshold."""
+
+    check_on_train_epoch_end: Optional[bool] = Field(
+        default=False, validate_default=True
     )
-    auto_lr_find_stop_divergence: bool = Field(default=True, validate_default=True)
-    auto_lr_find_step_scale: float = Field(default=0.1, ge=0.0, le=10)
+    """Whether to run early stopping at the end of the training epoch. If this is
+    `False`, then the check runs at the end of the validation."""
+
+    log_rank_zero_only: bool = Field(default=False, validate_default=True)
+    """When set `True`, logs the status of the early stopping callback only for rank 0
+    process."""
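As a hedged illustration (values are arbitrary; the field names, defaults and bounds are those shown in the hunk above):

# Sketch: keep the 5 best val_loss checkpoints plus a last.ckpt copy, and stop
# training after 5 validation checks without improvement.
from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel

checkpoints = CheckpointModel(
    monitor="val_loss",
    save_top_k=5,       # within the ge=1 / le=10 bounds
    save_last=True,
    every_n_epochs=1,
)
early_stopping = EarlyStoppingModel(
    monitor="val_loss",
    patience=5,
    min_delta=0.001,
    mode="min",
)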
careamics/config/configuration_factory.py

@@ -1,12 +1,11 @@
 """Convenience functions to create configurations for training and inference."""
 
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional
 
 from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
 from .data_model import DataConfig
-from .inference_model import InferenceConfig
 from .support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -107,9 +106,6 @@ def _create_supervised_configuration(
     # augmentations
     if use_augmentations:
         transforms: List[Dict[str, Any]] = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
@@ -118,11 +114,7 @@
             },
         ]
     else:
-        transforms = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
-        ]
+        transforms = []
 
     # data model
     data = DataConfig(
@@ -250,7 +242,8 @@ def create_n2n_configuration(
     use_augmentations: bool = True,
     independent_channels: bool = False,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels: int = 1,
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_kwargs: Optional[dict] = None,
 ) -> Configuration:
@@ -260,10 +253,13 @@
     If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
     2.
 
-    If "C" is present in `axes`, then you need to set `n_channels` to the number of
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
     channels. Likewise, if you set the number of channels, then "C" must be present in
     `axes`.
 
+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
     By default, all channels are trained together. To train all channels independently,
     set `independent_channels` to True.
 
@@ -290,8 +286,10 @@
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels : int, optional
-        Number of channels (in and out), by default 1.
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_kwargs : dict, optional
@@ -302,6 +300,9 @@
     Configuration
         Configuration for training Noise2Noise.
     """
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
     return _create_supervised_configuration(
         algorithm="n2n",
         experiment_name=experiment_name,
@@ -313,8 +314,8 @@
         use_augmentations=use_augmentations,
         independent_channels=independent_channels,
         loss=loss,
-        n_channels_in=n_channels,
-        n_channels_out=n_channels,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
         logger=logger,
         model_kwargs=model_kwargs,
     )
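A hedged usage sketch of the new Noise2Noise factory signature (only `n_channels_in` and `n_channels_out` come from this hunk; the remaining parameter names are assumed from the usual factory signature and the docstring excerpt above):

# Sketch: Noise2Noise with 2 input channels and 1 output channel. Leaving
# n_channels_out at its default of -1 makes it fall back to n_channels_in.
from careamics.config import create_n2n_configuration

config = create_n2n_configuration(
    experiment_name="n2n_example",
    data_type="tiff",
    axes="CYX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=50,
    n_channels_in=2,
    n_channels_out=1,
)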
@@ -522,9 +523,6 @@ def create_n2v_configuration(
     # augmentations
     if use_augmentations:
         transforms: List[Dict[str, Any]] = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
@@ -533,11 +531,7 @@
             },
         ]
     else:
-        transforms = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
-        ]
+        transforms = []
 
     # n2v2 and structn2v
     nv2_transform = {
@@ -579,77 +573,3 @@
     )
 
     return configuration
-
-
-def create_inference_configuration(
-    configuration: Configuration,
-    tile_size: Optional[Tuple[int, ...]] = None,
-    tile_overlap: Optional[Tuple[int, ...]] = None,
-    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-    axes: Optional[str] = None,
-    tta_transforms: bool = True,
-    batch_size: Optional[int] = 1,
-) -> InferenceConfig:
-    """
-    Create a configuration for inference with N2V.
-
-    If not provided, `data_type` and `axes` are taken from the training
-    configuration.
-
-    Parameters
-    ----------
-    configuration : Configuration
-        Global configuration.
-    tile_size : Tuple[int, ...], optional
-        Size of the tiles.
-    tile_overlap : Tuple[int, ...], optional
-        Overlap of the tiles.
-    data_type : str, optional
-        Type of the data, by default "tiff".
-    axes : str, optional
-        Axes of the data, by default "YX".
-    tta_transforms : bool, optional
-        Whether to apply test-time augmentations, by default True.
-    batch_size : int, optional
-        Batch size, by default 1.
-
-    Returns
-    -------
-    InferenceConfiguration
-        Configuration used to configure CAREamicsPredictData.
-    """
-    if configuration.data_config.mean is None or configuration.data_config.std is None:
-        raise ValueError("Mean and std must be provided in the configuration.")
-
-    # tile size for UNets
-    if tile_size is not None:
-        model = configuration.algorithm_config.model
-
-        if model.architecture == SupportedArchitecture.UNET.value:
-            # tile size must be equal to k*2^n, where n is the number of pooling layers
-            # (equal to the depth) and k is an integer
-            depth = model.depth
-            tile_increment = 2**depth
-
-            for i, t in enumerate(tile_size):
-                if t % tile_increment != 0:
-                    raise ValueError(
-                        f"Tile size must be divisible by {tile_increment} along all "
-                        f"axes (got {t} for axis {i}). If your image size is smaller "
-                        f"along one axis (e.g. Z), consider padding the image."
-                    )
-
-        # tile overlaps must be specified
-        if tile_overlap is None:
-            raise ValueError("Tile overlap must be specified.")
-
-    return InferenceConfig(
-        data_type=data_type or configuration.data_config.data_type,
-        tile_size=tile_size,
-        tile_overlap=tile_overlap,
-        axes=axes or configuration.data_config.axes,
-        mean=configuration.data_config.mean,
-        std=configuration.data_config.std,
-        tta_transforms=tta_transforms,
-        batch_size=batch_size,
-    )
careamics/config/configuration_model.py

@@ -5,11 +5,11 @@ from __future__ import annotations
 import re
 from pathlib import Path
 from pprint import pformat
-from typing import Dict, List, Literal, Union
+from typing import Literal, Union
 
 import yaml
 from bioimageio.spec.generic.v0_3 import CiteEntry
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, field_validator, model_validator
 from typing_extensions import Self
 
 from .algorithm_model import AlgorithmConfig
@@ -147,20 +147,25 @@ class Configuration(BaseModel):
     )
 
     # version
-    version: Literal["0.1.0"] = Field(
-        default="0.1.0", description="Version of the CAREamics configuration."
-    )
+    version: Literal["0.1.0"] = "0.1.0"
+    """CAREamics configuration version."""
 
     # required parameters
-    experiment_name: str = Field(
-        ..., description="Name of the experiment, used to name logs and checkpoints."
-    )
+    experiment_name: str
+    """Name of the experiment, used to name logs and checkpoints."""
 
     # Sub-configurations
     algorithm_config: AlgorithmConfig
+    """Algorithm configuration, holding all parameters required to configure the
+    model."""
 
     data_config: DataConfig
+    """Data configuration, holding all parameters required to configure the training
+    data loader."""
+
     training_config: TrainingConfig
+    """Training configuration, holding all parameters required to configure the
+    training process."""
 
     @field_validator("experiment_name")
     @classmethod
@@ -269,7 +274,7 @@
         """
         return pformat(self.model_dump())
 
-    def set_3D(self, is_3D: bool, axes: str, patch_size: List[int]) -> None:
+    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
         """
         Set 3D flag and axes.
 
@@ -279,7 +284,7 @@
             Whether the algorithm is 3D or not.
         axes : str
             Axes of the data.
-        patch_size : List[int]
+        patch_size : list[int]
             Patch size.
         """
         # set the flag and axes (this will not trigger validation at the config level)
@@ -389,7 +394,7 @@
 
         return ""
 
-    def get_algorithm_citations(self) -> List[CiteEntry]:
+    def get_algorithm_citations(self) -> list[CiteEntry]:
         """
         Return a list of citation entries of the current algorithm.
 
@@ -455,13 +460,13 @@
 
         return ""
 
-    def get_algorithm_keywords(self) -> List[str]:
+    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.
 
        Returns
        -------
-        List[str]
+        list[str]
            List of keywords.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
@@ -491,8 +496,8 @@
         self,
         exclude_defaults: bool = False,
         exclude_none: bool = True,
-        **kwargs: Dict,
-    ) -> Dict:
+        **kwargs: dict,
+    ) -> dict:
         """
         Override model_dump method in order to set default values.
 
@@ -503,7 +508,7 @@
             True.
         exclude_none : bool, optional
             Whether to exclude fields with None values or not, by default True.
-        **kwargs : Dict
+        **kwargs : dict
             Keyword arguments.
 
         Returns
@@ -524,7 +529,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
 
     Parameters
    ----------
-    path : Union[str, Path]
+    path : str or Path
        Path to the configuration.
 
    Returns
@@ -556,7 +561,7 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    ----------
    config : Configuration
        Configuration to save.
-    path : Union[str, Path]
+    path : str or Path
        Path to a existing folder in which to save the configuration or to an existing
        configuration file.
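For completeness, a hedged sketch of round-tripping a configuration with the two helpers touched above (factory parameters other than the channel arguments are assumed, as in the earlier sketch; the target folder must already exist per the docstring):

# Sketch: save a Configuration to an existing folder and load it back.
from pathlib import Path

from careamics.config import create_n2n_configuration
from careamics.config.configuration_model import load_configuration, save_configuration

config = create_n2n_configuration(
    experiment_name="roundtrip_example",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=4,
    num_epochs=10,
)
saved = save_configuration(config, Path("experiments"))  # existing folder or config file
restored = load_configuration(saved)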