careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69) hide show
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -134,21 +134,6 @@ class AlgorithmConfig(BaseModel):
134
134
  "sure that `in_channels` and `num_classes` are the same."
135
135
  )
136
136
 
137
- # N2N
138
- if self.algorithm == "n2n":
139
- # n2n is only compatible with the UNet model
140
- if not isinstance(self.model, UNetModel):
141
- raise ValueError(
142
- f"Model for algorithm {self.algorithm} must be a `UNetModel`."
143
- )
144
-
145
- # n2n requires the number of input and output channels to be the same
146
- if self.model.in_channels != self.model.num_classes:
147
- raise ValueError(
148
- "N2N requires the same number of input and output channels. Make "
149
- "sure that `in_channels` and `num_classes` are the same."
150
- )
151
-
152
137
  if self.algorithm == "care" or self.algorithm == "n2n":
153
138
  if self.loss == "n2v":
154
139
  raise ValueError("Supervised algorithms do not support loss `n2v`.")
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from pprint import pformat
6
- from typing import Any, Dict, Literal
6
+ from typing import Any, Literal
7
7
 
8
8
  from pydantic import ConfigDict, field_validator, model_validator
9
9
  from torch.nn import Module
@@ -136,7 +136,7 @@ class CustomModel(ArchitectureModel):
136
136
  """
137
137
  return pformat(self.model_dump())
138
138
 
139
- def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
139
+ def model_dump(self, **kwargs: Any) -> dict[str, Any]:
140
140
  """Dump the model configuration.
141
141
 
142
142
  Parameters
@@ -146,7 +146,7 @@ class CustomModel(ArchitectureModel):
146
146
 
147
147
  Returns
148
148
  -------
149
- Dict[str, Any]
149
+ dict[str, Any]
150
150
  Model configuration.
151
151
  """
152
152
  model_dict = super().model_dump()
@@ -54,9 +54,6 @@ def full_configuration_example() -> Configuration:
54
54
  batch_size=8,
55
55
  axes="YX",
56
56
  transforms=[
57
- {
58
- "name": SupportedTransform.NORMALIZE.value,
59
- },
60
57
  {
61
58
  "name": SupportedTransform.XY_FLIP.value,
62
59
  },
@@ -107,9 +107,6 @@ def _create_supervised_configuration(
107
107
  # augmentations
108
108
  if use_augmentations:
109
109
  transforms: List[Dict[str, Any]] = [
110
- {
111
- "name": SupportedTransform.NORMALIZE.value,
112
- },
113
110
  {
114
111
  "name": SupportedTransform.XY_FLIP.value,
115
112
  },
@@ -118,11 +115,7 @@ def _create_supervised_configuration(
118
115
  },
119
116
  ]
120
117
  else:
121
- transforms = [
122
- {
123
- "name": SupportedTransform.NORMALIZE.value,
124
- },
125
- ]
118
+ transforms = []
126
119
 
127
120
  # data model
128
121
  data = DataConfig(
@@ -250,7 +243,8 @@ def create_n2n_configuration(
250
243
  use_augmentations: bool = True,
251
244
  independent_channels: bool = False,
252
245
  loss: Literal["mae", "mse"] = "mae",
253
- n_channels: int = 1,
246
+ n_channels_in: int = 1,
247
+ n_channels_out: int = -1,
254
248
  logger: Literal["wandb", "tensorboard", "none"] = "none",
255
249
  model_kwargs: Optional[dict] = None,
256
250
  ) -> Configuration:
@@ -260,10 +254,13 @@ def create_n2n_configuration(
260
254
  If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
261
255
  2.
262
256
 
263
- If "C" is present in `axes`, then you need to set `n_channels` to the number of
257
+ If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
264
258
  channels. Likewise, if you set the number of channels, then "C" must be present in
265
259
  `axes`.
266
260
 
261
+ To set the number of output channels, use the `n_channels_out` parameter. If it is
262
+ not specified, it will be assumed to be equal to `n_channels_in`.
263
+
267
264
  By default, all channels are trained together. To train all channels independently,
268
265
  set `independent_channels` to True.
269
266
 
@@ -290,8 +287,10 @@ def create_n2n_configuration(
290
287
  Whether to train all channels independently, by default False.
291
288
  loss : Literal["mae", "mse"], optional
292
289
  Loss function to use, by default "mae".
293
- n_channels : int, optional
294
- Number of channels (in and out), by default 1.
290
+ n_channels_in : int, optional
291
+ Number of channels in, by default 1.
292
+ n_channels_out : int, optional
293
+ Number of channels out, by default -1.
295
294
  logger : Literal["wandb", "tensorboard", "none"], optional
296
295
  Logger to use, by default "none".
297
296
  model_kwargs : dict, optional
@@ -302,6 +301,9 @@ def create_n2n_configuration(
302
301
  Configuration
303
302
  Configuration for training Noise2Noise.
304
303
  """
304
+ if n_channels_out == -1:
305
+ n_channels_out = n_channels_in
306
+
305
307
  return _create_supervised_configuration(
306
308
  algorithm="n2n",
307
309
  experiment_name=experiment_name,
@@ -313,8 +315,8 @@ def create_n2n_configuration(
313
315
  use_augmentations=use_augmentations,
314
316
  independent_channels=independent_channels,
315
317
  loss=loss,
316
- n_channels_in=n_channels,
317
- n_channels_out=n_channels,
318
+ n_channels_in=n_channels_in,
319
+ n_channels_out=n_channels_out,
318
320
  logger=logger,
319
321
  model_kwargs=model_kwargs,
320
322
  )
@@ -522,9 +524,6 @@ def create_n2v_configuration(
522
524
  # augmentations
523
525
  if use_augmentations:
524
526
  transforms: List[Dict[str, Any]] = [
525
- {
526
- "name": SupportedTransform.NORMALIZE.value,
527
- },
528
527
  {
529
528
  "name": SupportedTransform.XY_FLIP.value,
530
529
  },
@@ -533,11 +532,7 @@ def create_n2v_configuration(
533
532
  },
534
533
  ]
535
534
  else:
536
- transforms = [
537
- {
538
- "name": SupportedTransform.NORMALIZE.value,
539
- },
540
- ]
535
+ transforms = []
541
536
 
542
537
  # n2v2 and structn2v
543
538
  nv2_transform = {
@@ -618,7 +613,10 @@ def create_inference_configuration(
618
613
  InferenceConfiguration
619
614
  Configuration used to configure CAREamicsPredictData.
620
615
  """
621
- if configuration.data_config.mean is None or configuration.data_config.std is None:
616
+ if (
617
+ configuration.data_config.image_means is None
618
+ or configuration.data_config.image_stds is None
619
+ ):
622
620
  raise ValueError("Mean and std must be provided in the configuration.")
623
621
 
624
622
  # tile size for UNets
@@ -648,8 +646,8 @@ def create_inference_configuration(
648
646
  tile_size=tile_size,
649
647
  tile_overlap=tile_overlap,
650
648
  axes=axes or configuration.data_config.axes,
651
- mean=configuration.data_config.mean,
652
- std=configuration.data_config.std,
649
+ image_means=configuration.data_config.image_means,
650
+ image_stds=configuration.data_config.image_stds,
653
651
  tta_transforms=tta_transforms,
654
652
  batch_size=batch_size,
655
653
  )
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import re
6
6
  from pathlib import Path
7
7
  from pprint import pformat
8
- from typing import Dict, List, Literal, Union
8
+ from typing import Literal, Union
9
9
 
10
10
  import yaml
11
11
  from bioimageio.spec.generic.v0_3 import CiteEntry
@@ -269,7 +269,7 @@ class Configuration(BaseModel):
269
269
  """
270
270
  return pformat(self.model_dump())
271
271
 
272
- def set_3D(self, is_3D: bool, axes: str, patch_size: List[int]) -> None:
272
+ def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
273
273
  """
274
274
  Set 3D flag and axes.
275
275
 
@@ -279,7 +279,7 @@ class Configuration(BaseModel):
279
279
  Whether the algorithm is 3D or not.
280
280
  axes : str
281
281
  Axes of the data.
282
- patch_size : List[int]
282
+ patch_size : list[int]
283
283
  Patch size.
284
284
  """
285
285
  # set the flag and axes (this will not trigger validation at the config level)
@@ -389,7 +389,7 @@ class Configuration(BaseModel):
389
389
 
390
390
  return ""
391
391
 
392
- def get_algorithm_citations(self) -> List[CiteEntry]:
392
+ def get_algorithm_citations(self) -> list[CiteEntry]:
393
393
  """
394
394
  Return a list of citation entries of the current algorithm.
395
395
 
@@ -455,13 +455,13 @@ class Configuration(BaseModel):
455
455
 
456
456
  return ""
457
457
 
458
- def get_algorithm_keywords(self) -> List[str]:
458
+ def get_algorithm_keywords(self) -> list[str]:
459
459
  """
460
460
  Get algorithm keywords.
461
461
 
462
462
  Returns
463
463
  -------
464
- List[str]
464
+ list[str]
465
465
  List of keywords.
466
466
  """
467
467
  if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
@@ -491,8 +491,8 @@ class Configuration(BaseModel):
491
491
  self,
492
492
  exclude_defaults: bool = False,
493
493
  exclude_none: bool = True,
494
- **kwargs: Dict,
495
- ) -> Dict:
494
+ **kwargs: dict,
495
+ ) -> dict:
496
496
  """
497
497
  Override model_dump method in order to set default values.
498
498
 
@@ -503,7 +503,7 @@ class Configuration(BaseModel):
503
503
  True.
504
504
  exclude_none : bool, optional
505
505
  Whether to exclude fields with None values or not, by default True.
506
- **kwargs : Dict
506
+ **kwargs : dict
507
507
  Keyword arguments.
508
508
 
509
509
  Returns
@@ -524,7 +524,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
524
524
 
525
525
  Parameters
526
526
  ----------
527
- path : Union[str, Path]
527
+ path : str or Path
528
528
  Path to the configuration.
529
529
 
530
530
  Returns
@@ -556,7 +556,7 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
556
556
  ----------
557
557
  config : Configuration
558
558
  Configuration to save.
559
- path : Union[str, Path]
559
+ path : str or Path
560
560
  Path to an existing folder in which to save the configuration or to an existing
561
561
  configuration file.
562
562
 
@@ -3,8 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from pprint import pformat
6
- from typing import Any, List, Literal, Optional, Union
6
+ from typing import Any, Literal, Optional, Union
7
7
 
8
+ from numpy.typing import NDArray
8
9
  from pydantic import (
9
10
  BaseModel,
10
11
  ConfigDict,
@@ -17,7 +18,6 @@ from typing_extensions import Annotated, Self
17
18
 
18
19
  from .support import SupportedTransform
19
20
  from .transformations.n2v_manipulate_model import N2VManipulateModel
20
- from .transformations.normalize_model import NormalizeModel
21
21
  from .transformations.xy_flip_model import XYFlipModel
22
22
  from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
23
23
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -26,7 +26,6 @@ TRANSFORMS_UNION = Annotated[
26
26
  Union[
27
27
  XYFlipModel,
28
28
  XYRandomRotate90Model,
29
- NormalizeModel,
30
29
  N2VManipulateModel,
31
30
  ],
32
31
  Discriminator("name"), # used to tell the different transform models apart
@@ -39,7 +38,9 @@ class DataConfig(BaseModel):
39
38
 
40
39
  If std is specified, mean must be specified as well. Note that setting the std first
41
40
  and then the mean (if they were both `None` before) will raise a validation error.
42
- Prefer instead `set_mean_and_std` to set both at once.
41
+ Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
42
+ to be lists of floats, one for each channel. For supervised tasks, the mean and std
43
+ of the target could be different from the input data.
43
44
 
44
45
  All supported transforms are defined in the SupportedTransform enum.
45
46
 
@@ -55,7 +56,7 @@ class DataConfig(BaseModel):
55
56
  ... )
56
57
 
57
58
  To change the mean and std of the data:
58
- >>> data.set_mean_and_std(mean=214.3, std=84.5)
59
+ >>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
59
60
 
60
61
  One can pass also a list of transformations, by keyword, using the
61
62
  SupportedTransform value:
@@ -67,11 +68,6 @@ class DataConfig(BaseModel):
67
68
  ... axes="YX",
68
69
  ... transforms=[
69
70
  ... {
70
- ... "name": SupportedTransform.NORMALIZE.value,
71
- ... "mean": 167.6,
72
- ... "std": 47.2,
73
- ... },
74
- ... {
75
71
  ... "name": "XYFlip",
76
72
  ... }
77
73
  ... ]
@@ -85,19 +81,24 @@ class DataConfig(BaseModel):
85
81
 
86
82
  # Dataset configuration
87
83
  data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
88
- patch_size: Union[List[int]] = Field(..., min_length=2, max_length=3)
84
+ patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
89
85
  batch_size: int = Field(default=1, ge=1, validate_default=True)
90
86
  axes: str
91
87
 
92
88
  # Optional fields
93
- mean: Optional[float] = None
94
- std: Optional[float] = None
89
+ image_means: Optional[list[float]] = Field(
90
+ default=None, min_length=0, max_length=32
91
+ )
92
+ image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
93
+ target_means: Optional[list[float]] = Field(
94
+ default=None, min_length=0, max_length=32
95
+ )
96
+ target_stds: Optional[list[float]] = Field(
97
+ default=None, min_length=0, max_length=32
98
+ )
95
99
 
96
- transforms: List[TRANSFORMS_UNION] = Field(
100
+ transforms: list[TRANSFORMS_UNION] = Field(
97
101
  default=[
98
- {
99
- "name": SupportedTransform.NORMALIZE.value,
100
- },
101
102
  {
102
103
  "name": SupportedTransform.XY_FLIP.value,
103
104
  },
@@ -116,8 +117,8 @@ class DataConfig(BaseModel):
116
117
  @field_validator("patch_size")
117
118
  @classmethod
118
119
  def all_elements_power_of_2_minimum_8(
119
- cls, patch_list: Union[List[int]]
120
- ) -> Union[List[int]]:
120
+ cls, patch_list: Union[list[int]]
121
+ ) -> Union[list[int]]:
121
122
  """
122
123
  Validate patch size.
123
124
 
@@ -125,12 +126,12 @@ class DataConfig(BaseModel):
125
126
 
126
127
  Parameters
127
128
  ----------
128
- patch_list : Union[List[int]]
129
+ patch_list : list of int
129
130
  Patch size.
130
131
 
131
132
  Returns
132
133
  -------
133
- Union[List[int]]
134
+ list of int
134
135
  Validated patch size.
135
136
 
136
137
  Raises
@@ -180,19 +181,19 @@ class DataConfig(BaseModel):
180
181
  @field_validator("transforms")
181
182
  @classmethod
182
183
  def validate_prediction_transforms(
183
- cls, transforms: List[TRANSFORMS_UNION]
184
- ) -> List[TRANSFORMS_UNION]:
184
+ cls, transforms: list[TRANSFORMS_UNION]
185
+ ) -> list[TRANSFORMS_UNION]:
185
186
  """
186
187
  Validate N2VManipulate transform position in the transform list.
187
188
 
188
189
  Parameters
189
190
  ----------
190
- transforms : List[Transformations_Union]
191
+ transforms : list[Transformations_Union]
191
192
  Transforms.
192
193
 
193
194
  Returns
194
195
  -------
195
- List[TRANSFORMS_UNION]
196
+ list of transforms
196
197
  Validated transforms.
197
198
 
198
199
  Raises
@@ -235,29 +236,33 @@ class DataConfig(BaseModel):
235
236
  If std is not None and mean is None.
236
237
  """
237
238
  # check that mean and std are either both None, or both specified
238
- if (self.mean is None) != (self.std is None):
239
+ if (self.image_means and not self.image_stds) or (
240
+ self.image_stds and not self.image_means
241
+ ):
239
242
  raise ValueError(
240
243
  "Mean and std must be either both None, or both specified."
241
244
  )
242
245
 
243
- return self
246
+ elif (self.image_means is not None and self.image_stds is not None) and (
247
+ len(self.image_means) != len(self.image_stds)
248
+ ):
249
+ raise ValueError(
250
+ "Mean and std must be specified for each " "input channel."
251
+ )
244
252
 
245
- @model_validator(mode="after")
246
- def add_std_and_mean_to_normalize(self: Self) -> Self:
247
- """
248
- Add mean and std to the Normalize transform if it is present.
253
+ if (self.target_means and not self.target_stds) or (
254
+ self.target_stds and not self.target_means
255
+ ):
256
+ raise ValueError(
257
+ "Mean and std must be either both None, or both specified "
258
+ )
249
259
 
250
- Returns
251
- -------
252
- Self
253
- Data model with mean and std added to the Normalize transform.
254
- """
255
- if self.mean is not None and self.std is not None:
256
- # search in the transforms for Normalize and update parameters
257
- for transform in self.transforms:
258
- if transform.name == SupportedTransform.NORMALIZE.value:
259
- transform.mean = self.mean
260
- transform.std = self.std
260
+ elif self.target_means is not None and self.target_stds is not None:
261
+ if len(self.target_means) != len(self.target_stds):
262
+ raise ValueError(
263
+ "Mean and std must be either both None, or both specified for each "
264
+ "target channel."
265
+ )
261
266
 
262
267
  return self
263
268
 
@@ -341,7 +346,13 @@ class DataConfig(BaseModel):
341
346
  if self.has_n2v_manipulate():
342
347
  self.transforms.pop(-1)
343
348
 
344
- def set_mean_and_std(self, mean: float, std: float) -> None:
349
+ def set_mean_and_std(
350
+ self,
351
+ image_means: Union[NDArray, tuple, list, None],
352
+ image_stds: Union[NDArray, tuple, list, None],
353
+ target_means: Optional[Union[NDArray, tuple, list, None]] = None,
354
+ target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
355
+ ) -> None:
345
356
  """
346
357
  Set mean and standard deviation of the data.
347
358
 
@@ -350,14 +361,33 @@ class DataConfig(BaseModel):
350
361
 
351
362
  Parameters
352
363
  ----------
353
- mean : float
354
- Mean of the data.
355
- std : float
356
- Standard deviation of the data.
364
+ image_means : NDArray or tuple or list
365
+ Mean values for normalization.
366
+ image_stds : NDArray or tuple or list
367
+ Standard deviation values for normalization.
368
+ target_means : NDArray or tuple or list, optional
369
+ Target mean values for normalization, by default None.
370
+ target_stds : NDArray or tuple or list, optional
371
+ Target standard deviation values for normalization, by default None.
357
372
  """
358
- self._update(mean=mean, std=std)
373
+ # make sure we pass a list
374
+ if image_means is not None:
375
+ image_means = list(image_means)
376
+ if image_stds is not None:
377
+ image_stds = list(image_stds)
378
+ if target_means is not None:
379
+ target_means = list(target_means)
380
+ if target_stds is not None:
381
+ target_stds = list(target_stds)
382
+
383
+ self._update(
384
+ image_means=image_means,
385
+ image_stds=image_stds,
386
+ target_means=target_means,
387
+ target_stds=target_stds,
388
+ )
359
389
 
360
- def set_3D(self, axes: str, patch_size: List[int]) -> None:
390
+ def set_3D(self, axes: str, patch_size: list[int]) -> None:
361
391
  """
362
392
  Set 3D parameters.
363
393
 
@@ -365,7 +395,7 @@ class DataConfig(BaseModel):
365
395
  ----------
366
396
  axes : str
367
397
  Axes.
368
- patch_size : List[int]
398
+ patch_size : list of int
369
399
  Patch size.
370
400
  """
371
401
  self._update(axes=axes, patch_size=patch_size)
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Any, List, Literal, Optional, Union
5
+ from typing import Any, Literal, Optional, Union
6
6
 
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
@@ -17,17 +17,17 @@ class InferenceConfig(BaseModel):
17
17
 
18
18
  # Mandatory fields
19
19
  data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
20
- tile_size: Optional[Union[List[int]]] = Field(
20
+ tile_size: Optional[Union[list[int]]] = Field(
21
21
  default=None, min_length=2, max_length=3
22
22
  )
23
- tile_overlap: Optional[Union[List[int]]] = Field(
23
+ tile_overlap: Optional[Union[list[int]]] = Field(
24
24
  default=None, min_length=2, max_length=3
25
25
  )
26
26
 
27
27
  axes: str
28
28
 
29
- mean: float
30
- std: float = Field(..., ge=0.0)
29
+ image_means: list = Field(..., min_length=0, max_length=32)
30
+ image_stds: list = Field(..., min_length=0, max_length=32)
31
31
 
32
32
  # only default TTAs are supported for now
33
33
  tta_transforms: bool = Field(default=True)
@@ -38,8 +38,8 @@ class InferenceConfig(BaseModel):
38
38
  @field_validator("tile_overlap")
39
39
  @classmethod
40
40
  def all_elements_non_zero_even(
41
- cls, tile_overlap: Optional[Union[List[int]]]
42
- ) -> Optional[Union[List[int]]]:
41
+ cls, tile_overlap: Optional[list[int]]
42
+ ) -> Optional[list[int]]:
43
43
  """
44
44
  Validate tile overlap.
45
45
 
@@ -47,12 +47,12 @@ class InferenceConfig(BaseModel):
47
47
 
48
48
  Parameters
49
49
  ----------
50
- tile_overlap : Optional[Union[List[int]]]
50
+ tile_overlap : list[int] or None
51
51
  Tile overlap.
52
52
 
53
53
  Returns
54
54
  -------
55
- Optional[Union[List[int]]]
55
+ list[int] or None
56
56
  Validated tile overlap.
57
57
 
58
58
  Raises
@@ -77,19 +77,19 @@ class InferenceConfig(BaseModel):
77
77
  @field_validator("tile_size")
78
78
  @classmethod
79
79
  def tile_min_8_power_of_2(
80
- cls, tile_list: Optional[Union[List[int]]]
81
- ) -> Optional[Union[List[int]]]:
80
+ cls, tile_list: Optional[list[int]]
81
+ ) -> Optional[list[int]]:
82
82
  """
83
83
  Validate that each entry is greater than or equal to 8 and a power of 2.
84
84
 
85
85
  Parameters
86
86
  ----------
87
- tile_list : List[int]
87
+ tile_list : list of int
88
88
  Tile size.
89
89
 
90
90
  Returns
91
91
  -------
92
- List[int]
92
+ list of int
93
93
  Validated tile size.
94
94
 
95
95
  Raises
@@ -182,11 +182,23 @@ class InferenceConfig(BaseModel):
182
182
  If std is not None and mean is None.
183
183
  """
184
184
  # check that mean and std are either both None, or both specified
185
- if (self.mean is None) != (self.std is None):
185
+ if not self.image_means and not self.image_stds:
186
+ raise ValueError("Mean and std must be specified during inference.")
187
+
188
+ if (self.image_means and not self.image_stds) or (
189
+ self.image_stds and not self.image_means
190
+ ):
186
191
  raise ValueError(
187
192
  "Mean and std must be either both None, or both specified."
188
193
  )
189
194
 
195
+ elif (self.image_means is not None and self.image_stds is not None) and (
196
+ len(self.image_means) != len(self.image_stds)
197
+ ):
198
+ raise ValueError(
199
+ "Mean and std must be specified for each " "input channel."
200
+ )
201
+
190
202
  return self
191
203
 
192
204
  def _update(self, **kwargs: Any) -> None:
@@ -201,7 +213,7 @@ class InferenceConfig(BaseModel):
201
213
  self.__dict__.update(kwargs)
202
214
  self.__class__.model_validate(self.__dict__)
203
215
 
204
- def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None:
216
+ def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
205
217
  """
206
218
  Set 3D parameters.
207
219
 
@@ -209,9 +221,9 @@ class InferenceConfig(BaseModel):
209
221
  ----------
210
222
  axes : str
211
223
  Axes.
212
- tile_size : List[int]
224
+ tile_size : list of int
213
225
  Tile size.
214
- tile_overlap : List[int]
226
+ tile_overlap : list of int
215
227
  Tile overlap.
216
228
  """
217
229
  self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)