careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. (The original page linked to further details; that link is not preserved in this copy.)

@@ -4,11 +4,11 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
4
4
 
5
5
  from albumentations import Compose
6
6
 
7
- from .algorithm_model import AlgorithmModel
7
+ from .algorithm_model import AlgorithmConfig
8
8
  from .architectures import UNetModel
9
9
  from .configuration_model import Configuration
10
- from .data_model import DataModel
11
- from .inference_model import InferenceModel
10
+ from .data_model import DataConfig
11
+ from .inference_model import InferenceConfig
12
12
  from .support import (
13
13
  SupportedAlgorithm,
14
14
  SupportedArchitecture,
@@ -16,10 +16,11 @@ from .support import (
16
16
  SupportedPixelManipulation,
17
17
  SupportedTransform,
18
18
  )
19
- from .training_model import TrainingModel
19
+ from .training_model import TrainingConfig
20
20
 
21
21
 
22
- def create_n2n_configuration(
22
+ def _create_supervised_configuration(
23
+ algorithm: Literal["care", "n2n"],
23
24
  experiment_name: str,
24
25
  data_type: Literal["array", "tiff", "custom"],
25
26
  axes: str,
@@ -27,28 +28,18 @@ def create_n2n_configuration(
27
28
  batch_size: int,
28
29
  num_epochs: int,
29
30
  use_augmentations: bool = True,
30
- use_n2v2: bool = False,
31
- n_channels: int = 1,
31
+ loss: Literal["mae", "mse"] = "mae",
32
+ n_channels: int = -1,
32
33
  logger: Literal["wandb", "tensorboard", "none"] = "none",
33
34
  model_kwargs: Optional[dict] = None,
34
35
  ) -> Configuration:
35
36
  """
36
- Create a configuration for training N2V.
37
-
38
- If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
39
- 2.
40
-
41
- By setting `use_augmentations` to False, the only transformation applied will be
42
- normalization and N2V manipulation.
43
-
44
- The parameter `use_n2v2` overrides the corresponding `n2v2` that can be passed
45
- in `model_kwargs`.
46
-
47
- If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
48
- will be applied to each manipulated pixel.
37
+ Create a configuration for training CARE or Noise2Noise.
49
38
 
50
39
  Parameters
51
40
  ----------
41
+ algorithm : Literal["care", "n2n"]
42
+ Algorithm to use.
52
43
  experiment_name : str
53
44
  Name of the experiment.
54
45
  data_type : Literal["array", "tiff", "custom"]
@@ -63,18 +54,10 @@ def create_n2n_configuration(
63
54
  Number of epochs.
64
55
  use_augmentations : bool, optional
65
56
  Whether to use augmentations, by default True.
66
- use_n2v2 : bool, optional
67
- Whether to use N2V2, by default False.
57
+ loss : Literal["mae", "mse"], optional
58
+ Loss function to use, by default "mae".
68
59
  n_channels : int, optional
69
- Number of channels (in and out), by default 1.
70
- roi_size : int, optional
71
- N2V pixel manipulation area, by default 11.
72
- masked_pixel_percentage : float, optional
73
- Percentage of pixels masked in each patch, by default 0.2.
74
- struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
75
- Axis along which to apply structN2V mask, by default "none".
76
- struct_n2v_span : int, optional
77
- Span of the structN2V mask, by default 5.
60
+ Number of channels (in and out), by default -1.
78
61
  logger : Literal["wandb", "tensorboard", "none"], optional
79
62
  Logger to use, by default "none".
80
63
  model_kwargs : dict, optional
@@ -83,12 +66,23 @@ def create_n2n_configuration(
83
66
  Returns
84
67
  -------
85
68
  Configuration
86
- Configuration for training N2V.
69
+ Configuration for training CARE or Noise2Noise.
87
70
  """
71
+ # if there are channels, we need to specify their number
72
+ if "C" in axes and n_channels == 1:
73
+ raise ValueError(
74
+ f"Number of channels must be specified when using channels "
75
+ f"(got {n_channels} channel)."
76
+ )
77
+ elif "C" not in axes and n_channels > 1:
78
+ raise ValueError(
79
+ f"C is not present in the axes, but number of channels is specified "
80
+ f"(got {n_channels} channel)."
81
+ )
82
+
88
83
  # model
89
84
  if model_kwargs is None:
90
85
  model_kwargs = {}
91
- model_kwargs["n2v2"] = use_n2v2
92
86
  model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
93
87
  model_kwargs["in_channels"] = n_channels
94
88
  model_kwargs["num_classes"] = n_channels
@@ -99,9 +93,9 @@ def create_n2n_configuration(
99
93
  )
100
94
 
101
95
  # algorithm model
102
- algorithm = AlgorithmModel(
103
- algorithm=SupportedAlgorithm.N2V.value,
104
- loss=SupportedLoss.N2V.value,
96
+ algorithm = AlgorithmConfig(
97
+ algorithm=algorithm,
98
+ loss=loss,
105
99
  model=unet_model,
106
100
  )
107
101
 
@@ -126,7 +120,7 @@ def create_n2n_configuration(
126
120
  ]
127
121
 
128
122
  # data model
129
- data = DataModel(
123
+ data = DataConfig(
130
124
  data_type=data_type,
131
125
  axes=axes,
132
126
  patch_size=patch_size,
@@ -135,7 +129,7 @@ def create_n2n_configuration(
135
129
  )
136
130
 
137
131
  # training model
138
- training = TrainingModel(
132
+ training = TrainingConfig(
139
133
  num_epochs=num_epochs,
140
134
  batch_size=batch_size,
141
135
  logger=None if logger == "none" else logger,
@@ -152,6 +146,151 @@ def create_n2n_configuration(
152
146
  return configuration
153
147
 
154
148
 
149
+ def create_care_configuration(
150
+ experiment_name: str,
151
+ data_type: Literal["array", "tiff", "custom"],
152
+ axes: str,
153
+ patch_size: List[int],
154
+ batch_size: int,
155
+ num_epochs: int,
156
+ use_augmentations: bool = True,
157
+ loss: Literal["mae", "mse"] = "mae",
158
+ n_channels: int = 1,
159
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
160
+ model_kwargs: Optional[dict] = None,
161
+ ) -> Configuration:
162
+ """
163
+ Create a configuration for training CARE.
164
+
165
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
166
+ 2.
167
+
168
+ If "C" is present in `axes`, then you need to set `n_channels` to the number of
169
+ channels. Likewise, if you set the number of channels, then "C" must be present in
170
+ `axes`.
171
+
172
+ By setting `use_augmentations` to False, the only transformation applied will be
173
+ normalization.
174
+
175
+ Parameters
176
+ ----------
177
+ experiment_name : str
178
+ Name of the experiment.
179
+ data_type : Literal["array", "tiff", "custom"]
180
+ Type of the data.
181
+ axes : str
182
+ Axes of the data (e.g. SYX).
183
+ patch_size : List[int]
184
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
185
+ batch_size : int
186
+ Batch size.
187
+ num_epochs : int
188
+ Number of epochs.
189
+ use_augmentations : bool, optional
190
+ Whether to use augmentations, by default True.
191
+ loss : Literal["mae", "mse"], optional
192
+ Loss function to use, by default "mae".
193
+ n_channels : int, optional
194
+ Number of channels (in and out), by default 1.
195
+ logger : Literal["wandb", "tensorboard", "none"], optional
196
+ Logger to use, by default "none".
197
+ model_kwargs : dict, optional
198
+ UNetModel parameters, by default {}.
199
+
200
+ Returns
201
+ -------
202
+ Configuration
203
+ Configuration for training CARE.
204
+ """
205
+ return _create_supervised_configuration(
206
+ algorithm="care",
207
+ experiment_name=experiment_name,
208
+ data_type=data_type,
209
+ axes=axes,
210
+ patch_size=patch_size,
211
+ batch_size=batch_size,
212
+ num_epochs=num_epochs,
213
+ use_augmentations=use_augmentations,
214
+ loss=loss,
215
+ # TODO in the future we might support different in and out channels for CARE
216
+ n_channels=n_channels,
217
+ logger=logger,
218
+ model_kwargs=model_kwargs,
219
+ )
220
+
221
+
222
+ def create_n2n_configuration(
223
+ experiment_name: str,
224
+ data_type: Literal["array", "tiff", "custom"],
225
+ axes: str,
226
+ patch_size: List[int],
227
+ batch_size: int,
228
+ num_epochs: int,
229
+ use_augmentations: bool = True,
230
+ loss: Literal["mae", "mse"] = "mae",
231
+ n_channels: int = 1,
232
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
233
+ model_kwargs: Optional[dict] = None,
234
+ ) -> Configuration:
235
+ """
236
+ Create a configuration for training Noise2Noise.
237
+
238
+ If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
239
+ 2.
240
+
241
+ If "C" is present in `axes`, then you need to set `n_channels` to the number of
242
+ channels. Likewise, if you set the number of channels, then "C" must be present in
243
+ `axes`.
244
+
245
+ By setting `use_augmentations` to False, the only transformation applied will be
246
+ normalization.
247
+
248
+ Parameters
249
+ ----------
250
+ experiment_name : str
251
+ Name of the experiment.
252
+ data_type : Literal["array", "tiff", "custom"]
253
+ Type of the data.
254
+ axes : str
255
+ Axes of the data (e.g. SYX).
256
+ patch_size : List[int]
257
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
258
+ batch_size : int
259
+ Batch size.
260
+ num_epochs : int
261
+ Number of epochs.
262
+ use_augmentations : bool, optional
263
+ Whether to use augmentations, by default True.
264
+ loss : Literal["mae", "mse"], optional
265
+ Loss function to use, by default "mae".
266
+ n_channels : int, optional
267
+ Number of channels (in and out), by default 1.
268
+ logger : Literal["wandb", "tensorboard", "none"], optional
269
+ Logger to use, by default "none".
270
+ model_kwargs : dict, optional
271
+ UNetModel parameters, by default {}.
272
+
273
+ Returns
274
+ -------
275
+ Configuration
276
+ Configuration for training Noise2Noise.
277
+ """
278
+ return _create_supervised_configuration(
279
+ algorithm="n2n",
280
+ experiment_name=experiment_name,
281
+ data_type=data_type,
282
+ axes=axes,
283
+ patch_size=patch_size,
284
+ batch_size=batch_size,
285
+ num_epochs=num_epochs,
286
+ use_augmentations=use_augmentations,
287
+ loss=loss,
288
+ n_channels=n_channels,
289
+ logger=logger,
290
+ model_kwargs=model_kwargs,
291
+ )
292
+
293
+
155
294
  def create_n2v_configuration(
156
295
  experiment_name: str,
157
296
  data_type: Literal["array", "tiff", "custom"],
@@ -161,7 +300,7 @@ def create_n2v_configuration(
161
300
  num_epochs: int,
162
301
  use_augmentations: bool = True,
163
302
  use_n2v2: bool = False,
164
- n_channels: int = -1,
303
+ n_channels: int = 1,
165
304
  roi_size: int = 11,
166
305
  masked_pixel_percentage: float = 0.2,
167
306
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -170,7 +309,7 @@ def create_n2v_configuration(
170
309
  model_kwargs: Optional[dict] = None,
171
310
  ) -> Configuration:
172
311
  """
173
- Create a configuration for training N2V.
312
+ Create a configuration for training Noise2Void.
174
313
 
175
314
  N2V uses a UNet model to denoise images in a self-supervised manner. To use its
176
315
  variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
@@ -220,7 +359,7 @@ def create_n2v_configuration(
220
359
  use_n2v2 : bool, optional
221
360
  Whether to use N2V2, by default False.
222
361
  n_channels : int, optional
223
- Number of channels (in and out), by default -1.
362
+ Number of channels (in and out), by default 1.
224
363
  roi_size : int, optional
225
364
  N2V pixel manipulation area, by default 11.
226
365
  masked_pixel_percentage : float, optional
@@ -300,18 +439,16 @@ def create_n2v_configuration(
300
439
  ... )
301
440
  """
302
441
  # if there are channels, we need to specify their number
303
- if "C" in axes and n_channels == -1:
442
+ if "C" in axes and n_channels == 1:
304
443
  raise ValueError(
305
444
  f"Number of channels must be specified when using channels "
306
445
  f"(got {n_channels} channel)."
307
446
  )
308
- elif "C" not in axes and n_channels != -1:
447
+ elif "C" not in axes and n_channels > 1:
309
448
  raise ValueError(
310
449
  f"C is not present in the axes, but number of channels is specified "
311
450
  f"(got {n_channels} channel)."
312
451
  )
313
- elif n_channels == -1:
314
- n_channels = 1
315
452
 
316
453
  # model
317
454
  if model_kwargs is None:
@@ -327,7 +464,7 @@ def create_n2v_configuration(
327
464
  )
328
465
 
329
466
  # algorithm model
330
- algorithm = AlgorithmModel(
467
+ algorithm = AlgorithmConfig(
331
468
  algorithm=SupportedAlgorithm.N2V.value,
332
469
  loss=SupportedLoss.N2V.value,
333
470
  model=unet_model,
@@ -367,7 +504,7 @@ def create_n2v_configuration(
367
504
  transforms.append(nv2_transform)
368
505
 
369
506
  # data model
370
- data = DataModel(
507
+ data = DataConfig(
371
508
  data_type=data_type,
372
509
  axes=axes,
373
510
  patch_size=patch_size,
@@ -376,7 +513,7 @@ def create_n2v_configuration(
376
513
  )
377
514
 
378
515
  # training model
379
- training = TrainingModel(
516
+ training = TrainingConfig(
380
517
  num_epochs=num_epochs,
381
518
  batch_size=batch_size,
382
519
  logger=None if logger == "none" else logger,
@@ -403,7 +540,7 @@ def create_inference_configuration(
403
540
  transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
404
541
  tta_transforms: bool = True,
405
542
  batch_size: Optional[int] = 1,
406
- ) -> InferenceModel:
543
+ ) -> InferenceConfig:
407
544
  """
408
545
  Create a configuration for inference with N2V.
409
546
 
@@ -447,7 +584,7 @@ def create_inference_configuration(
447
584
  },
448
585
  ]
449
586
 
450
- return InferenceModel(
587
+ return InferenceConfig(
451
588
  data_type=data_type or training_configuration.data_config.data_type,
452
589
  tile_size=tile_size,
453
590
  tile_overlap=tile_overlap,
@@ -11,8 +11,8 @@ from bioimageio.spec.generic.v0_3 import CiteEntry
11
11
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
12
12
  from typing_extensions import Self
13
13
 
14
- from .algorithm_model import AlgorithmModel
15
- from .data_model import DataModel
14
+ from .algorithm_model import AlgorithmConfig
15
+ from .data_model import DataConfig
16
16
  from .references import (
17
17
  CARE,
18
18
  CUSTOM,
@@ -34,7 +34,7 @@ from .references import (
34
34
  StructN2VRef,
35
35
  )
36
36
  from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
37
- from .training_model import TrainingModel
37
+ from .training_model import TrainingConfig
38
38
  from .transformations.n2v_manipulate_model import (
39
39
  N2VManipulateModel,
40
40
  )
@@ -156,9 +156,10 @@ class Configuration(BaseModel):
156
156
  )
157
157
 
158
158
  # Sub-configurations
159
- algorithm_config: AlgorithmModel
160
- data_config: DataModel
161
- training_config: TrainingModel
159
+ algorithm_config: AlgorithmConfig
160
+
161
+ data_config: DataConfig
162
+ training_config: TrainingConfig
162
163
 
163
164
  @field_validator("experiment_name")
164
165
  @classmethod
@@ -591,6 +592,6 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
591
592
  # save configuration as dictionary to yaml
592
593
  with open(config_path, "w") as f:
593
594
  # dump configuration
594
- yaml.dump(config.model_dump(), f, default_flow_style=False)
595
+ yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
595
596
 
596
597
  return config_path
@@ -33,7 +33,7 @@ TRANSFORMS_UNION = Annotated[
33
33
  ]
34
34
 
35
35
 
36
- class DataModel(BaseModel):
36
+ class DataConfig(BaseModel):
37
37
  """
38
38
  Data configuration.
39
39
 
@@ -45,7 +45,7 @@ class DataModel(BaseModel):
45
45
  --------
46
46
  Minimum example:
47
47
 
48
- >>> data = DataModel(
48
+ >>> data = DataConfig(
49
49
  ... data_type="array", # defined in SupportedData
50
50
  ... patch_size=[128, 128],
51
51
  ... batch_size=4,
@@ -58,7 +58,7 @@ class DataModel(BaseModel):
58
58
  One can pass also a list of transformations, by keyword, using the
59
59
  SupportedTransform or the name of an Albumentation transform:
60
60
  >>> from careamics.config.support import SupportedTransform
61
- >>> data = DataModel(
61
+ >>> data = DataConfig(
62
62
  ... data_type="tiff",
63
63
  ... patch_size=[128, 128],
64
64
  ... batch_size=4,
@@ -14,7 +14,7 @@ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
14
14
  TRANSFORMS_UNION = Union[NormalizeModel]
15
15
 
16
16
 
17
- class InferenceModel(BaseModel):
17
+ class InferenceConfig(BaseModel):
18
18
  """Configuration class for the prediction model."""
19
19
 
20
20
  model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
@@ -15,7 +15,7 @@ class SupportedOptimizer(str, BaseEnum):
15
15
  # ASGD = "ASGD"
16
16
  # Adadelta = "Adadelta"
17
17
  # Adagrad = "Adagrad"
18
- Adam = "Adam"
18
+ ADAM = "Adam"
19
19
  # AdamW = "AdamW"
20
20
  # Adamax = "Adamax"
21
21
  # LBFGS = "LBFGS"
@@ -50,6 +50,6 @@ class SupportedScheduler(str, BaseEnum):
50
50
  # MultiplicativeLR = "MultiplicativeLR"
51
51
  # OneCycleLR = "OneCycleLR"
52
52
  # PolynomialLR = "PolynomialLR"
53
- ReduceLROnPlateau = "ReduceLROnPlateau"
53
+ REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
54
54
  # SequentialLR = "SequentialLR"
55
- StepLR = "StepLR"
55
+ STEP_LR = "StepLR"
@@ -13,7 +13,7 @@ from pydantic import (
13
13
  from .callback_model import CheckpointModel, EarlyStoppingModel
14
14
 
15
15
 
16
- class TrainingModel(BaseModel):
16
+ class TrainingConfig(BaseModel):
17
17
  """
18
18
  Parameters related to the training.
19
19
 
@@ -30,7 +30,7 @@ class N2VManipulateModel(TransformModel):
30
30
  validate_assignment=True,
31
31
  )
32
32
 
33
- name: Literal["N2VManipulate"]
33
+ name: Literal["N2VManipulate"] = "N2VManipulate"
34
34
  roi_size: int = Field(default=11, ge=3, le=21)
35
35
  masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
36
36
  strategy: Literal["uniform", "median"] = Field(default="uniform")
@@ -26,7 +26,7 @@ class NDFlipModel(TransformModel):
26
26
  validate_assignment=True,
27
27
  )
28
28
 
29
- name: Literal["NDFlip"]
29
+ name: Literal["NDFlip"] = "NDFlip"
30
30
  p: float = Field(default=0.5, ge=0.0, le=1.0)
31
31
  is_3D: bool = Field(default=False)
32
32
  flip_z: bool = Field(default=True)
@@ -26,6 +26,6 @@ class NormalizeModel(TransformModel):
26
26
  validate_assignment=True,
27
27
  )
28
28
 
29
- name: Literal["Normalize"]
29
+ name: Literal["Normalize"] = "Normalize"
30
30
  mean: float = Field(default=0.485) # albumentations defaults
31
31
  std: float = Field(default=0.229)
@@ -24,6 +24,6 @@ class XYRandomRotate90Model(TransformModel):
24
24
  validate_assignment=True,
25
25
  )
26
26
 
27
- name: Literal["XYRandomRotate90"]
27
+ name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
28
28
  p: float = Field(default=0.5, ge=0.0, le=1.0)
29
29
  is_3D: bool = Field(default=False)
@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
8
8
  import numpy as np
9
9
  from torch.utils.data import Dataset
10
10
 
11
- from ..config import DataModel, InferenceModel
11
+ from ..config import DataConfig, InferenceConfig
12
12
  from ..config.tile_information import TileInformation
13
13
  from ..utils.logging import get_logger
14
14
  from .dataset_utils import read_tiff, reshape_array
@@ -29,7 +29,7 @@ class InMemoryDataset(Dataset):
29
29
 
30
30
  def __init__(
31
31
  self,
32
- data_config: DataModel,
32
+ data_config: DataConfig,
33
33
  inputs: Union[np.ndarray, List[Path]],
34
34
  data_target: Optional[Union[np.ndarray, List[Path]]] = None,
35
35
  read_source_func: Callable = read_tiff,
@@ -279,7 +279,7 @@ class InMemoryPredictionDataset(Dataset):
279
279
 
280
280
  def __init__(
281
281
  self,
282
- prediction_config: InferenceModel,
282
+ prediction_config: InferenceConfig,
283
283
  inputs: np.ndarray,
284
284
  data_target: Optional[np.ndarray] = None,
285
285
  read_source_func: Optional[Callable] = read_tiff,
@@ -7,7 +7,7 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
7
7
  import numpy as np
8
8
  from torch.utils.data import IterableDataset, get_worker_info
9
9
 
10
- from ..config import DataModel, InferenceModel
10
+ from ..config import DataConfig, InferenceConfig
11
11
  from ..config.tile_information import TileInformation
12
12
  from ..utils.logging import get_logger
13
13
  from .dataset_utils import read_tiff, reshape_array
@@ -46,7 +46,7 @@ class PathIterableDataset(IterableDataset):
46
46
 
47
47
  def __init__(
48
48
  self,
49
- data_config: Union[DataModel, InferenceModel],
49
+ data_config: Union[DataConfig, InferenceConfig],
50
50
  src_files: List[Path],
51
51
  target_files: Optional[List[Path]] = None,
52
52
  read_source_func: Callable = read_tiff,
@@ -346,7 +346,7 @@ class IterablePredictionDataset(PathIterableDataset):
346
346
 
347
347
  def __init__(
348
348
  self,
349
- prediction_config: InferenceModel,
349
+ prediction_config: InferenceConfig,
350
350
  src_files: List[Path],
351
351
  read_source_func: Callable = read_tiff,
352
352
  **kwargs: Any,