careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; the original page linked to further details.

Files changed (118):
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
File: careamics/dataset/in_memory_tiled_pred_dataset.py (new file)
@@ -0,0 +1,129 @@
1
+ """In-memory tiled prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.tile_information import TileInformation
12
+ from ..config.transformations import NormalizeModel
13
+ from .dataset_utils import reshape_array
14
+ from .tiling import extract_tiles
15
+
16
+
17
+ class InMemoryTiledPredDataset(Dataset):
18
+ """Prediction dataset storing data in memory and returning tiles of each image.
19
+
20
+ Parameters
21
+ ----------
22
+ prediction_config : InferenceConfig
23
+ Prediction configuration.
24
+ inputs : NDArray
25
+ Input data.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ prediction_config: InferenceConfig,
31
+ inputs: NDArray,
32
+ ) -> None:
33
+ """Constructor.
34
+
35
+ Parameters
36
+ ----------
37
+ prediction_config : InferenceConfig
38
+ Prediction configuration.
39
+ inputs : NDArray
40
+ Input data.
41
+
42
+ Raises
43
+ ------
44
+ ValueError
45
+ If data_path is not a directory.
46
+ """
47
+ if (
48
+ prediction_config.tile_size is None
49
+ or prediction_config.tile_overlap is None
50
+ ):
51
+ raise ValueError(
52
+ "Tile size and overlap must be provided to use the tiled prediction "
53
+ "dataset."
54
+ )
55
+
56
+ self.pred_config = prediction_config
57
+ self.input_array = inputs
58
+ self.axes = self.pred_config.axes
59
+ self.tile_size = prediction_config.tile_size
60
+ self.tile_overlap = prediction_config.tile_overlap
61
+ self.image_means = self.pred_config.image_means
62
+ self.image_stds = self.pred_config.image_stds
63
+
64
+ # Generate patches
65
+ self.data = self._prepare_tiles()
66
+
67
+ # get transforms
68
+ self.patch_transform = Compose(
69
+ transform_list=[
70
+ NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
71
+ ],
72
+ )
73
+
74
+ def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
75
+ """
76
+ Iterate over data source and create an array of patches.
77
+
78
+ Returns
79
+ -------
80
+ list of tuples of NDArray and TileInformation
81
+ List of tiles and tile information.
82
+ """
83
+ # reshape array
84
+ reshaped_sample = reshape_array(self.input_array, self.axes)
85
+
86
+ # generate patches, which returns a generator
87
+ patch_generator = extract_tiles(
88
+ arr=reshaped_sample,
89
+ tile_size=self.tile_size,
90
+ overlaps=self.tile_overlap,
91
+ )
92
+ patches_list = list(patch_generator)
93
+
94
+ if len(patches_list) == 0:
95
+ raise ValueError("No tiles generated, ")
96
+
97
+ return patches_list
98
+
99
+ def __len__(self) -> int:
100
+ """
101
+ Return the length of the dataset.
102
+
103
+ Returns
104
+ -------
105
+ int
106
+ Length of the dataset.
107
+ """
108
+ return len(self.data)
109
+
110
+ def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
111
+ """
112
+ Return the patch corresponding to the provided index.
113
+
114
+ Parameters
115
+ ----------
116
+ index : int
117
+ Index of the patch to return.
118
+
119
+ Returns
120
+ -------
121
+ tuple of NDArray and TileInformation
122
+ Transformed patch.
123
+ """
124
+ tile_array, tile_info = self.data[index]
125
+
126
+ # Apply transforms
127
+ transformed_tile, _ = self.patch_transform(patch=tile_array)
128
+
129
+ return transformed_tile, tile_info
File: careamics/dataset/iterable_dataset.py (modified)
@@ -1,20 +1,27 @@
1
+ """Iterable dataset used to load data file by file."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import copy
6
+ from collections.abc import Generator
4
7
  from pathlib import Path
5
- from typing import Any, Callable, Generator, List, Optional, Tuple, Union
8
+ from typing import Callable, Optional
6
9
 
7
10
  import numpy as np
8
- from torch.utils.data import IterableDataset, get_worker_info
11
+ from torch.utils.data import IterableDataset
9
12
 
13
+ from careamics.config import DataConfig
14
+ from careamics.config.transformations import NormalizeModel
10
15
  from careamics.transforms import Compose
11
16
 
12
- from ..config import DataConfig, InferenceConfig
13
- from ..config.tile_information import TileInformation
14
17
  from ..utils.logging import get_logger
15
- from .dataset_utils import read_tiff, reshape_array
18
+ from .dataset_utils import (
19
+ iterate_over_files,
20
+ read_tiff,
21
+ )
22
+ from .dataset_utils.running_stats import WelfordStatistics
23
+ from .patching.patching import Stats, StatsOutput
16
24
  from .patching.random_patching import extract_patches_random
17
- from .patching.tiled_patching import extract_tiles
18
25
 
19
26
  logger = get_logger(__name__)
20
27
 
@@ -25,129 +32,142 @@ class PathIterableDataset(IterableDataset):
25
32
 
26
33
  Parameters
27
34
  ----------
28
- data_path : Union[str, Path]
35
+ data_config : DataConfig
36
+ Data configuration.
37
+ src_files : list of pathlib.Path
38
+ List of data files.
39
+ target_files : list of pathlib.Path, optional
40
+ Optional list of target files, by default None.
41
+ read_source_func : Callable, optional
42
+ Read source function for custom types, by default read_tiff.
43
+
44
+ Attributes
45
+ ----------
46
+ data_path : list of pathlib.Path
29
47
  Path to the data, must be a directory.
30
48
  axes : str
31
49
  Description of axes in format STCZYX.
32
- patch_extraction_method : Union[ExtractionStrategies, None]
33
- Patch extraction strategy, as defined in extraction_strategy.
34
- patch_size : Optional[Union[List[int], Tuple[int]]], optional
35
- Size of the patches in each dimension, by default None.
36
- patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
37
- Overlap of the patches in each dimension, by default None.
38
- mean : Optional[float], optional
39
- Expected mean of the dataset, by default None.
40
- std : Optional[float], optional
41
- Expected standard deviation of the dataset, by default None.
42
- patch_transform : Optional[Callable], optional
43
- Patch transform callable, by default None.
44
50
  """
45
51
 
46
52
  def __init__(
47
53
  self,
48
- data_config: Union[DataConfig, InferenceConfig],
49
- src_files: List[Path],
50
- target_files: Optional[List[Path]] = None,
54
+ data_config: DataConfig,
55
+ src_files: list[Path],
56
+ target_files: Optional[list[Path]] = None,
51
57
  read_source_func: Callable = read_tiff,
52
58
  ) -> None:
59
+ """Constructors.
60
+
61
+ Parameters
62
+ ----------
63
+ data_config : DataConfig
64
+ Data configuration.
65
+ src_files : list[Path]
66
+ List of data files.
67
+ target_files : list[Path] or None, optional
68
+ Optional list of target files, by default None.
69
+ read_source_func : Callable, optional
70
+ Read source function for custom types, by default read_tiff.
71
+ """
53
72
  self.data_config = data_config
54
73
  self.data_files = src_files
55
74
  self.target_files = target_files
56
- self.data_config = data_config
57
75
  self.read_source_func = read_source_func
58
76
 
59
77
  # compute mean and std over the dataset
60
- if not data_config.mean or not data_config.std:
61
- self.mean, self.std = self._calculate_mean_and_std()
78
+ # only checking the image_mean because the DataConfig class ensures that
79
+ # if image_mean is provided, image_std is also provided
80
+ if not self.data_config.image_means:
81
+ self.data_stats = self._calculate_mean_and_std()
82
+ logger.info(
83
+ f"Computed dataset mean: {self.data_stats.image_stats.means},"
84
+ f"std: {self.data_stats.image_stats.stds}"
85
+ )
86
+
87
+ # update the mean in the config
88
+ self.data_config.set_mean_and_std(
89
+ image_means=self.data_stats.image_stats.means,
90
+ image_stds=self.data_stats.image_stats.stds,
91
+ target_means=(
92
+ list(self.data_stats.target_stats.means)
93
+ if self.data_stats.target_stats.means is not None
94
+ else None
95
+ ),
96
+ target_stds=(
97
+ list(self.data_stats.target_stats.stds)
98
+ if self.data_stats.target_stats.stds is not None
99
+ else None
100
+ ),
101
+ )
62
102
 
63
- # update mean and std in configuration
64
- # the object is mutable and should then be recorded in the CAREamist
65
- data_config.set_mean_and_std(self.mean, self.std)
66
103
  else:
67
- self.mean = data_config.mean
68
- self.std = data_config.std
104
+ # if mean and std are provided in the config, use them
105
+ self.data_stats = StatsOutput(
106
+ Stats(self.data_config.image_means, self.data_config.image_stds),
107
+ Stats(self.data_config.target_means, self.data_config.target_stds),
108
+ )
69
109
 
70
- # get transforms
71
- self.patch_transform = Compose(transform_list=data_config.transforms)
110
+ # create transform composed of normalization and other transforms
111
+ self.patch_transform = Compose(
112
+ transform_list=[
113
+ NormalizeModel(
114
+ image_means=self.data_stats.image_stats.means,
115
+ image_stds=self.data_stats.image_stats.stds,
116
+ target_means=self.data_stats.target_stats.means,
117
+ target_stds=self.data_stats.target_stats.stds,
118
+ )
119
+ ]
120
+ + data_config.transforms
121
+ )
72
122
 
73
- def _calculate_mean_and_std(self) -> Tuple[float, float]:
123
+ def _calculate_mean_and_std(self) -> StatsOutput:
74
124
  """
75
125
  Calculate mean and std of the dataset.
76
126
 
77
127
  Returns
78
128
  -------
79
- Tuple[float, float]
80
- Tuple containing mean and standard deviation.
129
+ PatchedOutput
130
+ Data class containing the image statistics.
81
131
  """
82
- means, stds = 0, 0
83
132
  num_samples = 0
133
+ image_stats = WelfordStatistics()
134
+ if self.target_files is not None:
135
+ target_stats = WelfordStatistics()
136
+
137
+ for sample, target in iterate_over_files(
138
+ self.data_config, self.data_files, self.target_files, self.read_source_func
139
+ ):
140
+ # update the image statistics
141
+ image_stats.update(sample, num_samples)
142
+
143
+ # update the target statistics if target is available
144
+ if target is not None:
145
+ target_stats.update(target, num_samples)
84
146
 
85
- for sample, _ in self._iterate_over_files():
86
- means += sample.mean()
87
- stds += sample.std()
88
147
  num_samples += 1
89
148
 
90
149
  if num_samples == 0:
91
150
  raise ValueError("No samples found in the dataset.")
92
151
 
93
- result_mean = means / num_samples
94
- result_std = stds / num_samples
95
-
96
- logger.info(f"Calculated mean and std for {num_samples} images")
97
- logger.info(f"Mean: {result_mean}, std: {result_std}")
98
- return result_mean, result_std
152
+ # Average the means and stds per sample
153
+ image_means, image_stds = image_stats.finalize()
99
154
 
100
- def _iterate_over_files(
101
- self,
102
- ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
103
- """
104
- Iterate over data source and yield whole image.
155
+ if target is not None:
156
+ target_means, target_stds = target_stats.finalize()
105
157
 
106
- Yields
107
- ------
108
- np.ndarray
109
- Image.
110
- """
111
- # When num_workers > 0, each worker process will have a different copy of the
112
- # dataset object
113
- # Configuring each copy independently to avoid having duplicate data returned
114
- # from the workers
115
- worker_info = get_worker_info()
116
- worker_id = worker_info.id if worker_info is not None else 0
117
- num_workers = worker_info.num_workers if worker_info is not None else 1
118
-
119
- # iterate over the files
120
- for i, filename in enumerate(self.data_files):
121
- # retrieve file corresponding to the worker id
122
- if i % num_workers == worker_id:
123
- try:
124
- # read data
125
- sample = self.read_source_func(filename, self.data_config.axes)
126
-
127
- # read target, if available
128
- if self.target_files is not None:
129
- if filename.name != self.target_files[i].name:
130
- raise ValueError(
131
- f"File {filename} does not match target file "
132
- f"{self.target_files[i]}. Have you passed sorted "
133
- f"arrays?"
134
- )
135
-
136
- # read target
137
- target = self.read_source_func(
138
- self.target_files[i], self.data_config.axes
139
- )
140
-
141
- yield sample, target
142
- else:
143
- yield sample, None
144
-
145
- except Exception as e:
146
- logger.error(f"Error reading file {filename}: {e}")
158
+ logger.info(f"Calculated mean and std for {num_samples} images")
159
+ logger.info(f"Mean: {image_means}, std: {image_stds}")
160
+ return StatsOutput(
161
+ Stats(image_means, image_stds),
162
+ Stats(
163
+ np.array(target_means) if target is not None else None,
164
+ np.array(target_stds) if target is not None else None,
165
+ ),
166
+ )
147
167
 
148
168
  def __iter__(
149
169
  self,
150
- ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]:
170
+ ) -> Generator[tuple[np.ndarray, ...], None, None]:
151
171
  """
152
172
  Iterate over data source and yield single patch.
153
173
 
@@ -157,22 +177,18 @@ class PathIterableDataset(IterableDataset):
157
177
  Single patch.
158
178
  """
159
179
  assert (
160
- self.mean is not None and self.std is not None
180
+ self.data_stats.image_stats.means is not None
181
+ and self.data_stats.image_stats.stds is not None
161
182
  ), "Mean and std must be provided"
162
183
 
163
184
  # iterate over files
164
- for sample_input, sample_target in self._iterate_over_files():
165
- reshaped_sample = reshape_array(sample_input, self.data_config.axes)
166
- reshaped_target = (
167
- None
168
- if sample_target is None
169
- else reshape_array(sample_target, self.data_config.axes)
170
- )
171
-
185
+ for sample_input, sample_target in iterate_over_files(
186
+ self.data_config, self.data_files, self.target_files, self.read_source_func
187
+ ):
172
188
  patches = extract_patches_random(
173
- arr=reshaped_sample,
189
+ arr=sample_input,
174
190
  patch_size=self.data_config.patch_size,
175
- target=reshaped_target,
191
+ target=sample_target,
176
192
  )
177
193
 
178
194
  # iterate over patches
@@ -209,9 +225,9 @@ class PathIterableDataset(IterableDataset):
209
225
  Parameters
210
226
  ----------
211
227
  percentage : float, optional
212
- Percentage of files to split up, by default 0.1
228
+ Percentage of files to split up, by default 0.1.
213
229
  minimum_number : int, optional
214
- Minimum number of files to split up, by default 5
230
+ Minimum number of files to split up, by default 5.
215
231
 
216
232
  Returns
217
233
  -------
@@ -273,89 +289,3 @@ class PathIterableDataset(IterableDataset):
273
289
  dataset.target_files = val_target_files
274
290
 
275
291
  return dataset
276
-
277
-
278
- class IterablePredictionDataset(PathIterableDataset):
279
- """
280
- Dataset allowing extracting patches w/o loading whole data into memory.
281
-
282
- Parameters
283
- ----------
284
- data_path : Union[str, Path]
285
- Path to the data, must be a directory.
286
- axes : str
287
- Description of axes in format STCZYX.
288
- mean : Optional[float], optional
289
- Expected mean of the dataset, by default None.
290
- std : Optional[float], optional
291
- Expected standard deviation of the dataset, by default None.
292
- patch_transform : Optional[Callable], optional
293
- Patch transform callable, by default None.
294
- """
295
-
296
- def __init__(
297
- self,
298
- prediction_config: InferenceConfig,
299
- src_files: List[Path],
300
- read_source_func: Callable = read_tiff,
301
- **kwargs: Any,
302
- ) -> None:
303
- super().__init__(
304
- data_config=prediction_config,
305
- src_files=src_files,
306
- read_source_func=read_source_func,
307
- )
308
-
309
- self.prediction_config = prediction_config
310
- self.axes = prediction_config.axes
311
- self.tile_size = self.prediction_config.tile_size
312
- self.tile_overlap = self.prediction_config.tile_overlap
313
- self.read_source_func = read_source_func
314
-
315
- # tile only if both tile size and overlaps are provided
316
- self.tile = self.tile_size is not None and self.tile_overlap is not None
317
-
318
- # get tta transforms
319
- self.patch_transform = Compose(
320
- transform_list=prediction_config.transforms,
321
- )
322
-
323
- def __iter__(
324
- self,
325
- ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
326
- """
327
- Iterate over data source and yield single patch.
328
-
329
- Yields
330
- ------
331
- np.ndarray
332
- Single patch.
333
- """
334
- assert (
335
- self.mean is not None and self.std is not None
336
- ), "Mean and std must be provided"
337
-
338
- for sample, _ in self._iterate_over_files():
339
- # reshape array
340
- reshaped_sample = reshape_array(sample, self.axes)
341
-
342
- if self.tile:
343
- # generate patches, return a generator
344
- patch_gen = extract_tiles(
345
- arr=reshaped_sample,
346
- tile_size=self.tile_size,
347
- overlaps=self.tile_overlap,
348
- )
349
- else:
350
- # just wrap the sample in a generator with default tiling info
351
- array_shape = reshaped_sample.squeeze().shape
352
- patch_gen = (
353
- (reshaped_sample, TileInformation(array_shape=array_shape))
354
- for _ in range(1)
355
- )
356
-
357
- # apply transform to patches
358
- for patch_array, tile_info in patch_gen:
359
- transformed_patch, _ = self.patch_transform(patch=patch_array)
360
-
361
- yield transformed_patch, tile_info
File: careamics/dataset/iterable_pred_dataset.py (new file)
@@ -0,0 +1,121 @@
1
+ """Iterable prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Generator
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import IterableDataset
10
+
11
+ from careamics.transforms import Compose
12
+
13
+ from ..config import InferenceConfig
14
+ from ..config.transformations import NormalizeModel
15
+ from .dataset_utils import iterate_over_files, read_tiff
16
+
17
+
18
+ class IterablePredDataset(IterableDataset):
19
+ """Simple iterable prediction dataset.
20
+
21
+ Parameters
22
+ ----------
23
+ prediction_config : InferenceConfig
24
+ Inference configuration.
25
+ src_files : List[Path]
26
+ List of data files.
27
+ read_source_func : Callable, optional
28
+ Read source function for custom types, by default read_tiff.
29
+ **kwargs : Any
30
+ Additional keyword arguments, unused.
31
+
32
+ Attributes
33
+ ----------
34
+ data_path : Union[str, Path]
35
+ Path to the data, must be a directory.
36
+ axes : str
37
+ Description of axes in format STCZYX.
38
+ mean : Optional[float], optional
39
+ Expected mean of the dataset, by default None.
40
+ std : Optional[float], optional
41
+ Expected standard deviation of the dataset, by default None.
42
+ patch_transform : Optional[Callable], optional
43
+ Patch transform callable, by default None.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ prediction_config: InferenceConfig,
49
+ src_files: list[Path],
50
+ read_source_func: Callable = read_tiff,
51
+ **kwargs: Any,
52
+ ) -> None:
53
+ """Constructor.
54
+
55
+ Parameters
56
+ ----------
57
+ prediction_config : InferenceConfig
58
+ Inference configuration.
59
+ src_files : list of pathlib.Path
60
+ List of data files.
61
+ read_source_func : Callable, optional
62
+ Read source function for custom types, by default read_tiff.
63
+ **kwargs : Any
64
+ Additional keyword arguments, unused.
65
+
66
+ Raises
67
+ ------
68
+ ValueError
69
+ If mean and std are not provided in the inference configuration.
70
+ """
71
+ self.prediction_config = prediction_config
72
+ self.data_files = src_files
73
+ self.axes = prediction_config.axes
74
+ self.read_source_func = read_source_func
75
+
76
+ # check mean and std and create normalize transform
77
+ if (
78
+ self.prediction_config.image_means is None
79
+ or self.prediction_config.image_stds is None
80
+ ):
81
+ raise ValueError("Mean and std must be provided for prediction.")
82
+ else:
83
+ self.image_means = self.prediction_config.image_means
84
+ self.image_stds = self.prediction_config.image_stds
85
+
86
+ # instantiate normalize transform
87
+ self.patch_transform = Compose(
88
+ transform_list=[
89
+ NormalizeModel(
90
+ image_means=self.image_means,
91
+ image_stds=self.image_stds,
92
+ )
93
+ ],
94
+ )
95
+
96
+ def __iter__(
97
+ self,
98
+ ) -> Generator[NDArray, None, None]:
99
+ """
100
+ Iterate over data source and yield single patch.
101
+
102
+ Yields
103
+ ------
104
+ NDArray
105
+ Single patch.
106
+ """
107
+ assert (
108
+ self.image_means is not None and self.image_stds is not None
109
+ ), "Mean and std must be provided"
110
+
111
+ for sample, _ in iterate_over_files(
112
+ self.prediction_config,
113
+ self.data_files,
114
+ read_source_func=self.read_source_func,
115
+ ):
116
+ # sample has S dimension
117
+ for i in range(sample.shape[0]):
118
+
119
+ transformed_sample, _ = self.patch_transform(patch=sample[i])
120
+
121
+ yield transformed_sample