careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.



Files changed (54)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +83 -62
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -0
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +2 -0
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +1 -79
  11. careamics/config/configuration_model.py +12 -7
  12. careamics/config/data_model.py +29 -10
  13. careamics/config/inference_model.py +12 -2
  14. careamics/config/optimizer_models.py +6 -0
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/tile_information.py +10 -0
  17. careamics/config/training_model.py +5 -1
  18. careamics/dataset/dataset_utils/__init__.py +0 -6
  19. careamics/dataset/dataset_utils/file_utils.py +1 -1
  20. careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
  21. careamics/dataset/in_memory_dataset.py +37 -21
  22. careamics/dataset/iterable_dataset.py +38 -34
  23. careamics/dataset/iterable_pred_dataset.py +2 -1
  24. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  25. careamics/dataset/patching/patching.py +53 -37
  26. careamics/file_io/__init__.py +7 -0
  27. careamics/file_io/read/__init__.py +11 -0
  28. careamics/file_io/read/get_func.py +56 -0
  29. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
  30. careamics/file_io/write/__init__.py +9 -0
  31. careamics/file_io/write/get_func.py +59 -0
  32. careamics/file_io/write/tiff.py +39 -0
  33. careamics/lightning/__init__.py +17 -0
  34. careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
  35. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
  36. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
  37. careamics/model_io/bmz_io.py +1 -1
  38. careamics/model_io/model_io_utils.py +1 -1
  39. careamics/prediction_utils/__init__.py +0 -2
  40. careamics/prediction_utils/prediction_outputs.py +18 -46
  41. careamics/prediction_utils/stitch_prediction.py +17 -14
  42. careamics/utils/__init__.py +2 -0
  43. careamics/utils/autocorrelation.py +40 -0
  44. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
  45. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
  46. careamics/config/configuration_example.py +0 -86
  47. careamics/dataset/dataset_utils/read_utils.py +0 -27
  48. careamics/prediction_utils/create_pred_datamodule.py +0 -185
  49. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  50. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  51. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  52. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  53. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
  54. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -9,7 +9,7 @@ from typing import Literal, Union

  import yaml
  from bioimageio.spec.generic.v0_3 import CiteEntry
- from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+ from pydantic import BaseModel, ConfigDict, field_validator, model_validator
  from typing_extensions import Self

  from .algorithm_model import AlgorithmConfig
@@ -147,20 +147,25 @@ class Configuration(BaseModel):
  )

  # version
- version: Literal["0.1.0"] = Field(
- default="0.1.0", description="Version of the CAREamics configuration."
- )
+ version: Literal["0.1.0"] = "0.1.0"
+ """CAREamics configuration version."""

  # required parameters
- experiment_name: str = Field(
- ..., description="Name of the experiment, used to name logs and checkpoints."
- )
+ experiment_name: str
+ """Name of the experiment, used to name logs and checkpoints."""

  # Sub-configurations
  algorithm_config: AlgorithmConfig
+ """Algorithm configuration, holding all parameters required to configure the
+ model."""

  data_config: DataConfig
+ """Data configuration, holding all parameters required to configure the training
+ data loader."""
+
  training_config: TrainingConfig
+ """Training configuration, holding all parameters required to configure the
+ training process."""

  @field_validator("experiment_name")
  @classmethod
@@ -55,8 +55,8 @@ class DataConfig(BaseModel):
  ... axes="YX"
  ... )

- To change the mean and std of the data:
- >>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
+ To change the image_means and image_stds of the data:
+ >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

  One can pass also a list of transformations, by keyword, using the
  SupportedTransform value:
@@ -80,22 +80,38 @@ class DataConfig(BaseModel):
  )

  # Dataset configuration
- data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
+ data_type: Literal["array", "tiff", "custom"]
+ """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
+ in SupportedData."""
+
+ axes: str
+ """Axes of the data, as defined in SupportedAxes."""
+
  patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
+ """Patch size, as used during training."""
+
  batch_size: int = Field(default=1, ge=1, validate_default=True)
- axes: str
+ """Batch size for training."""

  # Optional fields
  image_means: Optional[list[float]] = Field(
  default=None, min_length=0, max_length=32
  )
+ """Means of the data across channels, used for normalization."""
+
  image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+ """Standard deviations of the data across channels, used for normalization."""
+
  target_means: Optional[list[float]] = Field(
  default=None, min_length=0, max_length=32
  )
+ """Means of the target data across channels, used for normalization."""
+
  target_stds: Optional[list[float]] = Field(
  default=None, min_length=0, max_length=32
  )
+ """Standard deviations of the target data across channels, used for
+ normalization."""

  transforms: list[TRANSFORMS_UNION] = Field(
  default=[
@@ -111,8 +127,11 @@ class DataConfig(BaseModel):
  ],
  validate_default=True,
  )
+ """List of transformations to apply to the data, available transforms are defined
+ in SupportedTransform. The default values are set for Noise2Void."""

  dataloader_params: Optional[dict] = None
+ """Dictionary of PyTorch dataloader parameters."""

  @field_validator("patch_size")
  @classmethod
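
Note: a minimal sketch of how these documented fields fit together when building a DataConfig; the concrete values are illustrative (mirroring the docstring example above), not part of this diff.

from careamics.config import DataConfig

# data_type, axes, patch_size and batch_size are the fields documented in this hunk
data = DataConfig(
    data_type="array",
    axes="YX",
    patch_size=[128, 128],
    batch_size=4,
)
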
@@ -346,7 +365,7 @@ class DataConfig(BaseModel):
  if self.has_n2v_manipulate():
  self.transforms.pop(-1)

- def set_mean_and_std(
+ def set_means_and_stds(
  self,
  image_means: Union[NDArray, tuple, list, None],
  image_stds: Union[NDArray, tuple, list, None],
@@ -354,20 +373,20 @@ class DataConfig(BaseModel):
  target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
  ) -> None:
  """
- Set mean and standard deviation of the data.
+ Set mean and standard deviation of the data across channels.

  This method should be used instead setting the fields directly, as it would
  otherwise trigger a validation error.

  Parameters
  ----------
- image_means : NDArray or tuple or list
+ image_means : numpy.ndarray ,tuple or list
  Mean values for normalization.
- image_stds : NDArray or tuple or list
+ image_stds : numpy.ndarray, tuple or list
  Standard deviation values for normalization.
- target_means : NDArray or tuple or list, optional
+ target_means : numpy.ndarray, tuple or list, optional
  Target mean values for normalization, by default ().
- target_stds : NDArray or tuple or list, optional
+ target_stds : numpy.ndarray, tuple or list, optional
  Target standard deviation values for normalization, by default ().
  """
  # make sure we pass a list
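
Note: with the rename above, callers update from set_mean_and_std to set_means_and_stds; a short usage sketch, with the statistics values taken from the docstring example earlier in this diff.

# update normalization statistics after construction; setting the fields
# directly would trigger a validation error
data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
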
@@ -15,25 +15,35 @@ class InferenceConfig(BaseModel):

  model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

- # Mandatory fields
  data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
+ """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
+
  tile_size: Optional[Union[list[int]]] = Field(
  default=None, min_length=2, max_length=3
  )
+ """Tile size of prediction, only effective if `tile_overlap` is specified."""
+
  tile_overlap: Optional[Union[list[int]]] = Field(
  default=None, min_length=2, max_length=3
  )
+ """Overlap between tiles, only effective if `tile_size` is specified."""

  axes: str
+ """Data axes (TSCZYX) in the order of the input data."""

  image_means: list = Field(..., min_length=0, max_length=32)
+ """Mean values for each input channel."""
+
  image_stds: list = Field(..., min_length=0, max_length=32)
+ """Standard deviation values for each input channel."""

- # only default TTAs are supported for now
+ # TODO only default TTAs are supported for now
  tta_transforms: bool = Field(default=True)
+ """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

  # Dataloader parameters
  batch_size: int = Field(default=1, ge=1)
+ """Batch size for prediction."""

  @field_validator("tile_overlap")
  @classmethod
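
Note: a minimal sketch of an InferenceConfig using the fields documented above; all values are illustrative assumptions.

from careamics.config import InferenceConfig

pred_config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[214.3],     # one value per input channel
    image_stds=[84.5],
    tile_size=[128, 128],    # tiling only takes effect if tile_overlap is also set
    tile_overlap=[32, 32],
)
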
@@ -45,6 +45,7 @@ class OptimizerModel(BaseModel):

  # Mandatory field
  name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+ """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

  # Optional parameters, empty dict default value to allow filtering dictionary
  parameters: dict = Field(
@@ -53,6 +54,7 @@ class OptimizerModel(BaseModel):
  },
  validate_default=True,
  )
+ """Parameters of the optimizer, see PyTorch documentation for more details."""

  @field_validator("parameters")
  @classmethod
@@ -140,9 +142,13 @@ class LrSchedulerModel(BaseModel):

  # Mandatory field
  name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+ """Name of the learning rate scheduler, supported schedulers are defined in
+ SupportedScheduler."""

  # Optional parameters
  parameters: dict = Field(default={}, validate_default=True)
+ """Parameters of the learning rate scheduler, see PyTorch documentation for more
+ details."""

  @field_validator("parameters")
  @classmethod
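
Note: a hedged sketch of the two models documented above; the parameter dictionaries are forwarded to the corresponding PyTorch classes, and the keys shown here are illustrative (import path assumed from the file location listed in this diff).

from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-4})
scheduler = LrSchedulerModel(name="ReduceLROnPlateau", parameters={"factor": 0.5})
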
@@ -60,9 +60,9 @@ class SupportedData(str, BaseEnum):
  return super()._missing_(value)

  @classmethod
- def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+ def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
  """
- Path.rglob and fnmatch compatible extension.
+ Get Path.rglob and fnmatch compatible extension.

  Parameters
  ----------
@@ -72,13 +72,38 @@ class SupportedData(str, BaseEnum):
  Returns
  -------
  str
- Corresponding extension.
+ Corresponding extension pattern.
  """
  if data_type == cls.ARRAY:
- raise NotImplementedError(f"Data {data_type} are not loaded from file.")
+ raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
  elif data_type == cls.TIFF:
  return "*.tif*"
  elif data_type == cls.CUSTOM:
  return "*.*"
  else:
  raise ValueError(f"Data type {data_type} is not supported.")
+
+ @classmethod
+ def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+ """
+ Get file extension of corresponding data type.
+
+ Parameters
+ ----------
+ data_type : str or SupportedData
+ Data type.
+
+ Returns
+ -------
+ str
+ Corresponding extension.
+ """
+ if data_type == cls.ARRAY:
+ raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
+ elif data_type == cls.TIFF:
+ return ".tiff"
+ elif data_type == cls.CUSTOM:
+ # TODO: improve this message
+ raise NotImplementedError("Custom extensions have to be passed elsewhere.")
+ else:
+ raise ValueError(f"Data type {data_type} is not supported.")
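
Note: to illustrate the split between the renamed get_extension_pattern and the new get_extension, a short sketch based on the branches shown above (import path assumed from the file location listed in this diff).

from careamics.config.support.supported_data import SupportedData

SupportedData.get_extension_pattern(SupportedData.TIFF)  # "*.tif*", for rglob/fnmatch
SupportedData.get_extension(SupportedData.TIFF)          # ".tiff", a plain file extension
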
@@ -19,10 +19,20 @@ class TileInformation(BaseModel):
  model_config = ConfigDict(validate_default=True)

  array_shape: tuple[int, ...]
+ """Shape of the original (untiled) array."""
+
  last_tile: bool = False
+ """Whether this tile is the last one of the array."""
+
  overlap_crop_coords: tuple[tuple[int, ...], ...]
+ """Inner coordinates of the tile where to crop the prediction in order to stitch
+ it back into the original image."""
+
  stitch_coords: tuple[tuple[int, ...], ...]
+ """Coordinates in the original image where to stitch the cropped tile back."""
+
  sample_id: int
+ """Sample ID of the tile."""

  @field_validator("array_shape")
  @classmethod
@@ -35,15 +35,19 @@ class TrainingConfig(BaseModel):
  )

  num_epochs: int = Field(default=20, ge=1)
+ """Number of epochs, greater than 0."""

  logger: Optional[Literal["wandb", "tensorboard"]] = None
+ """Logger to use during training. If None, no logger will be used. Available
+ loggers are defined in SupportedLogger."""

  checkpoint_callback: CheckpointModel = CheckpointModel()
+ """Checkpoint callback configuration."""

  early_stopping_callback: Optional[EarlyStoppingModel] = Field(
  default=None, validate_default=True
  )
- # precision: Literal["64", "32", "16", "bf16"] = 32
+ """Early stopping callback configuration."""

  def __str__(self) -> str:
  """Pretty string reprensenting the configuration.
@@ -6,9 +6,6 @@ __all__ = [
  "get_files_size",
  "list_files",
  "validate_source_target_files",
- "read_tiff",
- "get_read_func",
- "read_zarr",
  "iterate_over_files",
  "WelfordStatistics",
  ]
@@ -19,7 +16,4 @@ from .dataset_utils import (
  )
  from .file_utils import get_files_size, list_files, validate_source_target_files
  from .iterate_over_files import iterate_over_files
- from .read_tiff import read_tiff
- from .read_utils import get_read_func
- from .read_zarr import read_zarr
  from .running_stats import WelfordStatistics, compute_normalization_stats
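
Note: since read_tiff, get_read_func and read_zarr are no longer re-exported from careamics.dataset.dataset_utils, downstream imports would presumably move to the new careamics.file_io package (the new path for read_tiff appears in the hunks below; the other functions are assumed to follow the file moves listed above).

# before (0.1.0rc7)
# from careamics.dataset.dataset_utils import read_tiff

# after (0.1.0rc8)
from careamics.file_io.read import read_tiff
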
@@ -75,7 +75,7 @@ def list_files(
  raise FileNotFoundError(f"Data path {data_path} does not exist.")

  # get extension compatible with fnmatch and rglob search
- extension = SupportedData.get_extension(data_type)
+ extension = SupportedData.get_extension_pattern(data_type)

  if data_type == SupportedData.CUSTOM and extension_filter != "":
  extension = extension_filter
@@ -9,10 +9,10 @@ from numpy.typing import NDArray
  from torch.utils.data import get_worker_info

  from careamics.config import DataConfig, InferenceConfig
+ from careamics.file_io.read import read_tiff
  from careamics.utils.logging import get_logger

  from .dataset_utils import reshape_array
- from .read_tiff import read_tiff

  logger = get_logger(__name__)

@@ -9,14 +9,15 @@ from typing import Any, Callable, Optional, Union
  import numpy as np
  from torch.utils.data import Dataset

+ from careamics.file_io.read import read_tiff
  from careamics.transforms import Compose

  from ..config import DataConfig
  from ..config.transformations import NormalizeModel
  from ..utils.logging import get_logger
- from .dataset_utils import read_tiff
  from .patching.patching import (
  PatchedOutput,
+ Stats,
  prepare_patches_supervised,
  prepare_patches_supervised_array,
  prepare_patches_unsupervised,
@@ -77,47 +78,50 @@ class InMemoryDataset(Dataset):
  # read function
  self.read_source_func = read_source_func

- # Generate patches
+ # generate patches
  supervised = self.input_targets is not None
  patches_data = self._prepare_patches(supervised)

- # Unpack the dataclass
+ # unpack the dataclass
  self.data = patches_data.patches
  self.data_targets = patches_data.targets

+ # set image statistics
  if self.data_config.image_means is None:
- self.image_means = patches_data.image_stats.means
- self.image_stds = patches_data.image_stats.stds
+ self.image_stats = patches_data.image_stats
  logger.info(
- f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
+ f"Computed dataset mean: {self.image_stats.means}, "
+ f"std: {self.image_stats.stds}"
  )
  else:
- self.image_means = self.data_config.image_means
- self.image_stds = self.data_config.image_stds
+ self.image_stats = Stats(
+ self.data_config.image_means, self.data_config.image_stds
+ )

+ # set target statistics
  if self.data_config.target_means is None:
- self.target_means = patches_data.target_stats.means
- self.target_stds = patches_data.target_stats.stds
+ self.target_stats = patches_data.target_stats
  else:
- self.target_means = self.data_config.target_means
- self.target_stds = self.data_config.target_stds
+ self.target_stats = Stats(
+ self.data_config.target_means, self.data_config.target_stds
+ )

  # update mean and std in configuration
  # the object is mutable and should then be recorded in the CAREamist obj
- self.data_config.set_mean_and_std(
- image_means=self.image_means,
- image_stds=self.image_stds,
- target_means=self.target_means,
- target_stds=self.target_stds,
+ self.data_config.set_means_and_stds(
+ image_means=self.image_stats.means,
+ image_stds=self.image_stats.stds,
+ target_means=self.target_stats.means,
+ target_stds=self.target_stats.stds,
  )
  # get transforms
  self.patch_transform = Compose(
  transform_list=[
  NormalizeModel(
- image_means=self.image_means,
- image_stds=self.image_stds,
- target_means=self.target_means,
- target_stds=self.target_stds,
+ image_means=self.image_stats.means,
+ image_stds=self.image_stats.stds,
+ target_means=self.target_stats.means,
+ target_stds=self.target_stats.stds,
  )
  ]
  + self.data_config.transforms,
@@ -223,6 +227,18 @@ class InMemoryDataset(Dataset):
  "and no N2V manipulation (no N2V training)."
  )

+ def get_data_statistics(self) -> tuple[list[float], list[float]]:
+ """Return training data statistics.
+
+ This does not return the target data statistics, only those of the input.
+
+ Returns
+ -------
+ tuple of list of floats
+ Means and standard deviations across channels of the training data.
+ """
+ return self.image_stats.get_statistics()
+
  def split_dataset(
  self,
  percentage: float = 0.1,
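
Note: the new get_data_statistics accessor exposes the input statistics that were previously read from the image_means/image_stds attributes; a hedged usage sketch, with the dataset construction omitted.

# `dataset` is an InMemoryDataset (PathIterableDataset gains the same method below)
means, stds = dataset.get_data_statistics()
print(f"means per channel: {means}, stds per channel: {stds}")
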
@@ -12,15 +12,13 @@ from torch.utils.data import IterableDataset

  from careamics.config import DataConfig
  from careamics.config.transformations import NormalizeModel
+ from careamics.file_io.read import read_tiff
  from careamics.transforms import Compose

  from ..utils.logging import get_logger
- from .dataset_utils import (
- iterate_over_files,
- read_tiff,
- )
+ from .dataset_utils import iterate_over_files
  from .dataset_utils.running_stats import WelfordStatistics
- from .patching.patching import Stats, StatsOutput
+ from .patching.patching import Stats
  from .patching.random_patching import extract_patches_random

  logger = get_logger(__name__)
@@ -78,31 +76,31 @@ class PathIterableDataset(IterableDataset):
  # only checking the image_mean because the DataConfig class ensures that
  # if image_mean is provided, image_std is also provided
  if not self.data_config.image_means:
- self.data_stats = self._calculate_mean_and_std()
+ self.image_stats, self.target_stats = self._calculate_mean_and_std()
  logger.info(
- f"Computed dataset mean: {self.data_stats.image_stats.means},"
- f"std: {self.data_stats.image_stats.stds}"
+ f"Computed dataset mean: {self.image_stats.means},"
+ f"std: {self.image_stats.stds}"
  )

  # update the mean in the config
- self.data_config.set_mean_and_std(
- image_means=self.data_stats.image_stats.means,
- image_stds=self.data_stats.image_stats.stds,
+ self.data_config.set_means_and_stds(
+ image_means=self.image_stats.means,
+ image_stds=self.image_stats.stds,
  target_means=(
- list(self.data_stats.target_stats.means)
- if self.data_stats.target_stats.means is not None
+ list(self.target_stats.means)
+ if self.target_stats.means is not None
  else None
  ),
  target_stds=(
- list(self.data_stats.target_stats.stds)
- if self.data_stats.target_stats.stds is not None
+ list(self.target_stats.stds)
+ if self.target_stats.stds is not None
  else None
  ),
  )

  else:
  # if mean and std are provided in the config, use them
- self.data_stats = StatsOutput(
+ self.image_stats, self.target_stats = (
  Stats(self.data_config.image_means, self.data_config.image_stds),
  Stats(self.data_config.target_means, self.data_config.target_stds),
  )
@@ -111,23 +109,23 @@ class PathIterableDataset(IterableDataset):
  self.patch_transform = Compose(
  transform_list=[
  NormalizeModel(
- image_means=self.data_stats.image_stats.means,
- image_stds=self.data_stats.image_stats.stds,
- target_means=self.data_stats.target_stats.means,
- target_stds=self.data_stats.target_stats.stds,
+ image_means=self.image_stats.means,
+ image_stds=self.image_stats.stds,
+ target_means=self.target_stats.means,
+ target_stds=self.target_stats.stds,
  )
  ]
  + data_config.transforms
  )

- def _calculate_mean_and_std(self) -> StatsOutput:
+ def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
  """
  Calculate mean and std of the dataset.

  Returns
  -------
- PatchedOutput
- Data class containing the image statistics.
+ tuple of Stats and optional Stats
+ Data classes containing the image and target statistics.
  """
  num_samples = 0
  image_stats = WelfordStatistics()
@@ -155,15 +153,12 @@ class PathIterableDataset(IterableDataset):
  if target is not None:
  target_means, target_stds = target_stats.finalize()

- logger.info(f"Calculated mean and std for {num_samples} images")
- logger.info(f"Mean: {image_means}, std: {image_stds}")
- return StatsOutput(
- Stats(image_means, image_stds),
- Stats(
- np.array(target_means) if target is not None else None,
- np.array(target_stds) if target is not None else None,
- ),
- )
+ return (
+ Stats(image_means, image_stds),
+ Stats(np.array(target_means), np.array(target_stds)),
+ )
+ else:
+ return Stats(image_means, image_stds), Stats(None, None)

  def __iter__(
  self,
@@ -177,8 +172,7 @@ class PathIterableDataset(IterableDataset):
  Single patch.
  """
  assert (
- self.data_stats.image_stats.means is not None
- and self.data_stats.image_stats.stds is not None
+ self.image_stats.means is not None and self.image_stats.stds is not None
  ), "Mean and std must be provided"

  # iterate over files
@@ -201,6 +195,16 @@ class PathIterableDataset(IterableDataset):
  target=patch_data[1],
  )

+ def get_data_statistics(self) -> tuple[list[float], list[float]]:
+ """Return training data statistics.
+
+ Returns
+ -------
+ tuple of list of floats
+ Means and standard deviations across channels of the training data.
+ """
+ return self.image_stats.get_statistics()
+
  def get_number_of_files(self) -> int:
  """
  Return the number of files in the dataset.
@@ -8,11 +8,12 @@ from typing import Any, Callable, Generator
  from numpy.typing import NDArray
  from torch.utils.data import IterableDataset

+ from careamics.file_io.read import read_tiff
  from careamics.transforms import Compose

  from ..config import InferenceConfig
  from ..config.transformations import NormalizeModel
- from .dataset_utils import iterate_over_files, read_tiff
+ from .dataset_utils import iterate_over_files


  class IterablePredDataset(IterableDataset):
@@ -8,12 +8,13 @@ from typing import Any, Callable, Generator
  from numpy.typing import NDArray
  from torch.utils.data import IterableDataset

+ from careamics.file_io.read import read_tiff
  from careamics.transforms import Compose

  from ..config import InferenceConfig
  from ..config.tile_information import TileInformation
  from ..config.transformations import NormalizeModel
- from .dataset_utils import iterate_over_files, read_tiff
+ from .dataset_utils import iterate_over_files
  from .tiling import extract_tiles
