careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118) hide show
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,186 @@
1
+ """Computing data statistics."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute the per-channel mean and standard deviation of an array.

    The channel axis is taken to be axis 1; the statistics are reduced over every
    other axis. The documented input layout is (S, C, (Z), Y, X).

    Parameters
    ----------
    image : NDArray
        Input array with the channel dimension on axis 1.

    Returns
    -------
    tuple of (NDArray, NDArray)
        Arrays holding the mean and the standard deviation of each channel.
    """
    # Every axis index except the channel axis (index 1)
    all_axes = np.arange(image.ndim)
    reduction_axes = tuple(np.delete(all_axes, 1))

    channel_means = np.mean(image, axis=reduction_axes)
    channel_stds = np.std(image, axis=reduction_axes)
    return channel_means, channel_stds
27
+
28
+
29
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update running per-channel statistics with a new batch of values.

    This applies Welford's update to a whole batch at once: the mean is shifted
    by the summed deviations divided by the new count, and ``m2`` accumulates the
    product of the deviations from the old and from the new mean.

    ``count``, ``mean`` and ``m2`` are modified in place and are also returned.

    Parameters
    ----------
    count : NDArray
        Number of values seen so far, one entry per channel.
    mean : NDArray
        Running mean, one entry per channel.
    m2 : NDArray
        Running sum of squared differences from the mean, one entry per channel.
        This is not the variance; divide by ``count`` to obtain it.
    new_values : NDArray
        New values to add. Iterating over the first axis must yield the values
        of one channel at a time.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and sum of squared differences.
    """
    # Flatten every channel once and reuse it for both deviation passes
    channels = [np.asarray(channel).ravel() for channel in new_values]

    count += np.array([channel.size for channel in channels])

    # new values - old mean
    # (a 1-element array broadcasts over the channel without building a
    # per-pixel Python list, and promotes dtypes like a full-size array would)
    delta = [v - np.atleast_1d(m) for v, m in zip(channels, mean)]

    mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])

    # new values - new mean
    delta2 = [v - np.atleast_1d(m) for v, m in zip(channels, mean)]

    m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])

    return (count, mean, m2)
67
+
68
+
69
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Finalize the mean and standard deviation computation.

    The standard deviation is the population one, i.e. ``sqrt(m2 / count)``.
    If any channel has fewer than two values, NaN-filled arrays are returned for
    both the mean and the standard deviation.

    Parameters
    ----------
    count : NDArray
        Number of values accumulated, one entry per channel.
    mean : NDArray
        Running mean, one entry per channel.
    m2 : NDArray
        Running sum of squared differences from the mean, one entry per channel.

    Returns
    -------
    tuple[NDArray, NDArray]
        Final mean and standard deviation.
    """
    counts = np.asarray(count)

    # Check the counts before dividing, so that an empty channel does not
    # trigger a division-by-zero warning for a result that is discarded anyway
    if np.any(counts < 2):
        return np.full(np.shape(mean), np.nan), np.full(np.shape(m2), np.nan)

    std = np.sqrt(np.asarray(m2) / counts)
    return mean, std
93
+
94
+
95
class WelfordStatistics:
    """Accumulate per-channel mean and standard deviation sample by sample.

    The statistics are computed iteratively with the Welford algorithm, based on
    the implementation described at:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    """

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Fold a new array into the running statistics.

        A ``sample_idx`` of 0 (re)initializes the running state from ``array``;
        any other value updates the state accumulated so far.

        Parameters
        ----------
        array : NDArray
            Input array, with channels along axis 1.
        sample_idx : int
            Current sample number.
        """
        self.sample_idx = sample_idx
        n_channels = array.shape[1]
        # One sub-array per channel, stacked along a new leading axis
        per_channel = np.array(np.split(array, n_channels, axis=1))

        if self.sample_idx != 0:
            # Continue from the statistics accumulated so far
            running = update_iterative_stats(
                count=self.count, mean=self.mean, m2=self.m2, new_values=per_channel
            )
        else:
            # First sample: seed the mean from the array itself and start the
            # count and m2 from zero-valued arrays of shape (C,)
            self.mean, _ = compute_normalization_stats(array)
            running = update_iterative_stats(
                count=np.zeros(n_channels),
                mean=self.mean,
                m2=np.zeros(n_channels),
                new_values=per_channel,
            )

        self.count, self.mean, self.m2 = running

        self.sample_idx += 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Return the final statistics from the accumulated state.

        Returns
        -------
        tuple of numpy arrays
            Final mean and standard deviation.
        """
        return finalize_iterative_stats(self.count, self.mean, self.m2)
144
+
145
+
146
+ # from multiprocessing import Value
147
+ # from typing import tuple
148
+
149
+ # import numpy as np
150
+
151
+
152
+ # class RunningStats:
153
+ # """Calculates running mean and std."""
154
+
155
+ # def __init__(self) -> None:
156
+ # self.reset()
157
+
158
+ # def reset(self) -> None:
159
+ # """Reset the running stats."""
160
+ # self.avg_mean = Value("d", 0)
161
+ # self.avg_std = Value("d", 0)
162
+ # self.m2 = Value("d", 0)
163
+ # self.count = Value("i", 0)
164
+
165
+ # def init(self, mean: float, std: float) -> None:
166
+ # """Initialize running stats."""
167
+ # with self.avg_mean.get_lock():
168
+ # self.avg_mean.value += mean
169
+ # with self.avg_std.get_lock():
170
+ # self.avg_std.value = std
171
+
172
+ # def compute_std(self) -> tuple[float, float]:
173
+ # """Compute std."""
174
+ # if self.count.value >= 2:
175
+ # self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
176
+
177
+ # def update(self, value: float) -> None:
178
+ # """Update running stats."""
179
+ # with self.count.get_lock():
180
+ # self.count.value += 1
181
+ # delta = value - self.avg_mean.value
182
+ # with self.avg_mean.get_lock():
183
+ # self.avg_mean.value += delta / self.count.value
184
+ # delta2 = value - self.avg_mean.value
185
+ # with self.m2.get_lock():
186
+ # self.m2.value += delta * delta2
@@ -4,47 +4,73 @@ from __future__ import annotations
4
4
 
5
5
  import copy
6
6
  from pathlib import Path
7
- from typing import Any, Callable, List, Optional, Tuple, Union
7
+ from typing import Any, Callable, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
12
  from careamics.transforms import Compose
13
13
 
14
- from ..config import DataConfig, InferenceConfig
15
- from ..config.tile_information import TileInformation
14
+ from ..config import DataConfig
15
+ from ..config.transformations import NormalizeModel
16
16
  from ..utils.logging import get_logger
17
- from .dataset_utils import read_tiff, reshape_array
17
+ from .dataset_utils import read_tiff
18
18
  from .patching.patching import (
19
+ PatchedOutput,
19
20
  prepare_patches_supervised,
20
21
  prepare_patches_supervised_array,
21
22
  prepare_patches_unsupervised,
22
23
  prepare_patches_unsupervised_array,
23
24
  )
24
- from .patching.tiled_patching import extract_tiles
25
25
 
26
26
  logger = get_logger(__name__)
27
27
 
28
28
 
29
29
  class InMemoryDataset(Dataset):
30
- """Dataset storing data in memory and allowing generating patches from it."""
30
+ """Dataset storing data in memory and allowing generating patches from it.
31
+
32
+ Parameters
33
+ ----------
34
+ data_config : CAREamics DataConfig
35
+ (see careamics.config.data_model.DataConfig)
36
+ Data configuration.
37
+ inputs : numpy.ndarray or list[pathlib.Path]
38
+ Input data.
39
+ input_target : numpy.ndarray or list[pathlib.Path], optional
40
+ Target data, by default None.
41
+ read_source_func : Callable, optional
42
+ Read source function for custom types, by default read_tiff.
43
+ **kwargs : Any
44
+ Additional keyword arguments, unused.
45
+ """
31
46
 
32
47
  def __init__(
33
48
  self,
34
49
  data_config: DataConfig,
35
- inputs: Union[np.ndarray, List[Path]],
36
- data_target: Optional[Union[np.ndarray, List[Path]]] = None,
50
+ inputs: Union[np.ndarray, list[Path]],
51
+ input_target: Optional[Union[np.ndarray, list[Path]]] = None,
37
52
  read_source_func: Callable = read_tiff,
38
53
  **kwargs: Any,
39
54
  ) -> None:
40
55
  """
41
56
  Constructor.
42
57
 
43
- # TODO
58
+ Parameters
59
+ ----------
60
+ data_config : DataConfig
61
+ Data configuration.
62
+ inputs : numpy.ndarray or list[pathlib.Path]
63
+ Input data.
64
+ input_target : numpy.ndarray or list[pathlib.Path], optional
65
+ Target data, by default None.
66
+ read_source_func : Callable, optional
67
+ Read source function for custom types, by default read_tiff.
68
+ **kwargs : Any
69
+ Additional keyword arguments, unused.
44
70
  """
45
71
  self.data_config = data_config
46
72
  self.inputs = inputs
47
- self.data_target = data_target
73
+ self.input_targets = input_target
48
74
  self.axes = self.data_config.axes
49
75
  self.patch_size = self.data_config.patch_size
50
76
 
@@ -52,30 +78,52 @@ class InMemoryDataset(Dataset):
52
78
  self.read_source_func = read_source_func
53
79
 
54
80
  # Generate patches
55
- supervised = self.data_target is not None
56
- patches = self._prepare_patches(supervised)
57
-
58
- # Add results to members
59
- self.data, self.data_targets, computed_mean, computed_std = patches
60
-
61
- if not self.data_config.mean or not self.data_config.std:
62
- self.mean, self.std = computed_mean, computed_std
63
- logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
64
-
65
- # update mean and std in configuration
66
- # the object is mutable and should then be recorded in the CAREamist obj
67
- self.data_config.set_mean_and_std(self.mean, self.std)
81
+ supervised = self.input_targets is not None
82
+ patches_data = self._prepare_patches(supervised)
83
+
84
+ # Unpack the dataclass
85
+ self.data = patches_data.patches
86
+ self.data_targets = patches_data.targets
87
+
88
+ if self.data_config.image_means is None:
89
+ self.image_means = patches_data.image_stats.means
90
+ self.image_stds = patches_data.image_stats.stds
91
+ logger.info(
92
+ f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
93
+ )
68
94
  else:
69
- self.mean, self.std = self.data_config.mean, self.data_config.std
95
+ self.image_means = self.data_config.image_means
96
+ self.image_stds = self.data_config.image_stds
70
97
 
98
+ if self.data_config.target_means is None:
99
+ self.target_means = patches_data.target_stats.means
100
+ self.target_stds = patches_data.target_stats.stds
101
+ else:
102
+ self.target_means = self.data_config.target_means
103
+ self.target_stds = self.data_config.target_stds
104
+
105
+ # update mean and std in configuration
106
+ # the object is mutable and should then be recorded in the CAREamist obj
107
+ self.data_config.set_mean_and_std(
108
+ image_means=self.image_means,
109
+ image_stds=self.image_stds,
110
+ target_means=self.target_means,
111
+ target_stds=self.target_stds,
112
+ )
71
113
  # get transforms
72
114
  self.patch_transform = Compose(
73
- transform_list=self.data_config.transforms,
115
+ transform_list=[
116
+ NormalizeModel(
117
+ image_means=self.image_means,
118
+ image_stds=self.image_stds,
119
+ target_means=self.target_means,
120
+ target_stds=self.target_stds,
121
+ )
122
+ ]
123
+ + self.data_config.transforms,
74
124
  )
75
125
 
76
- def _prepare_patches(
77
- self, supervised: bool
78
- ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
126
+ def _prepare_patches(self, supervised: bool) -> PatchedOutput:
79
127
  """
80
128
  Iterate over data source and create an array of patches.
81
129
 
@@ -86,23 +134,23 @@ class InMemoryDataset(Dataset):
86
134
 
87
135
  Returns
88
136
  -------
89
- np.ndarray
137
+ numpy.ndarray
90
138
  Array of patches.
91
139
  """
92
140
  if supervised:
93
141
  if isinstance(self.inputs, np.ndarray) and isinstance(
94
- self.data_target, np.ndarray
142
+ self.input_targets, np.ndarray
95
143
  ):
96
144
  return prepare_patches_supervised_array(
97
145
  self.inputs,
98
146
  self.axes,
99
- self.data_target,
147
+ self.input_targets,
100
148
  self.patch_size,
101
149
  )
102
- elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
150
+ elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
103
151
  return prepare_patches_supervised(
104
152
  self.inputs,
105
- self.data_target,
153
+ self.input_targets,
106
154
  self.axes,
107
155
  self.patch_size,
108
156
  self.read_source_func,
@@ -111,7 +159,7 @@ class InMemoryDataset(Dataset):
111
159
  raise ValueError(
112
160
  f"Data and target must be of the same type, either both numpy "
113
161
  f"arrays or both lists of paths, got {type(self.inputs)} (data) "
114
- f"and {type(self.data_target)} (target)."
162
+ f"and {type(self.input_targets)} (target)."
115
163
  )
116
164
  else:
117
165
  if isinstance(self.inputs, np.ndarray):
@@ -137,9 +185,9 @@ class InMemoryDataset(Dataset):
137
185
  int
138
186
  Length of the dataset.
139
187
  """
140
- return len(self.data)
188
+ return self.data.shape[0]
141
189
 
142
- def __getitem__(self, index: int) -> Tuple[np.ndarray]:
190
+ def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
143
191
  """
144
192
  Return the patch corresponding to the provided index.
145
193
 
@@ -150,7 +198,7 @@ class InMemoryDataset(Dataset):
150
198
 
151
199
  Returns
152
200
  -------
153
- Tuple[np.ndarray]
201
+ tuple of numpy.ndarray
154
202
  Patch.
155
203
 
156
204
  Raises
@@ -161,13 +209,13 @@ class InMemoryDataset(Dataset):
161
209
  patch = self.data[index]
162
210
 
163
211
  # if there is a target
164
- if self.data_target is not None:
212
+ if self.data_targets is not None:
165
213
  # get target
166
214
  target = self.data_targets[index]
167
215
 
168
216
  return self.patch_transform(patch=patch, target=target)
169
217
 
170
- elif self.data_config.has_n2v_manipulate():
218
+ elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
171
219
  return self.patch_transform(patch=patch)
172
220
  else:
173
221
  raise ValueError(
@@ -193,7 +241,7 @@ class InMemoryDataset(Dataset):
193
241
 
194
242
  Returns
195
243
  -------
196
- InMemoryDataset
244
+ CAREamics InMemoryDataset
197
245
  New dataset with the extracted patches.
198
246
 
199
247
  Raises
@@ -244,117 +292,3 @@ class InMemoryDataset(Dataset):
244
292
  dataset.data_targets = val_targets
245
293
 
246
294
  return dataset
247
-
248
-
249
- class InMemoryPredictionDataset(Dataset):
250
- """
251
- Dataset storing data in memory and allowing generating patches from it.
252
-
253
- # TODO
254
- """
255
-
256
- def __init__(
257
- self,
258
- prediction_config: InferenceConfig,
259
- inputs: np.ndarray,
260
- data_target: Optional[np.ndarray] = None,
261
- read_source_func: Optional[Callable] = read_tiff,
262
- ) -> None:
263
- """Constructor.
264
-
265
- Parameters
266
- ----------
267
- array : np.ndarray
268
- Array containing the data.
269
- axes : str
270
- Description of axes in format STCZYX.
271
-
272
- Raises
273
- ------
274
- ValueError
275
- If data_path is not a directory.
276
- """
277
- self.pred_config = prediction_config
278
- self.input_array = inputs
279
- self.axes = self.pred_config.axes
280
- self.tile_size = self.pred_config.tile_size
281
- self.tile_overlap = self.pred_config.tile_overlap
282
- self.mean = self.pred_config.mean
283
- self.std = self.pred_config.std
284
- self.data_target = data_target
285
-
286
- # tiling only if both tile size and overlap are provided
287
- self.tiling = self.tile_size is not None and self.tile_overlap is not None
288
-
289
- # read function
290
- self.read_source_func = read_source_func
291
-
292
- # Generate patches
293
- self.data = self._prepare_tiles()
294
- self.mean, self.std = self.pred_config.mean, self.pred_config.std
295
-
296
- # get transforms
297
- self.patch_transform = Compose(
298
- transform_list=self.pred_config.transforms,
299
- )
300
-
301
- def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
302
- """
303
- Iterate over data source and create an array of patches.
304
-
305
- Returns
306
- -------
307
- List[XArrayTile]
308
- List of tiles.
309
- """
310
- # reshape array
311
- reshaped_sample = reshape_array(self.input_array, self.axes)
312
-
313
- if self.tiling:
314
- # generate patches, which returns a generator
315
- patch_generator = extract_tiles(
316
- arr=reshaped_sample,
317
- tile_size=self.tile_size,
318
- overlaps=self.tile_overlap,
319
- )
320
- patches_list = list(patch_generator)
321
-
322
- if len(patches_list) == 0:
323
- raise ValueError("No tiles generated, ")
324
-
325
- return patches_list
326
- else:
327
- array_shape = reshaped_sample.squeeze().shape
328
- return [(reshaped_sample, TileInformation(array_shape=array_shape))]
329
-
330
- def __len__(self) -> int:
331
- """
332
- Return the length of the dataset.
333
-
334
- Returns
335
- -------
336
- int
337
- Length of the dataset.
338
- """
339
- return len(self.data)
340
-
341
- def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
342
- """
343
- Return the patch corresponding to the provided index.
344
-
345
- Parameters
346
- ----------
347
- index : int
348
- Index of the patch to return.
349
-
350
- Returns
351
- -------
352
- Tuple[np.ndarray, TileInformation]
353
- Transformed patch.
354
- """
355
- tile_array, tile_info = self.data[index]
356
-
357
- # Apply transforms
358
- transformed_tile, _ = self.patch_transform(patch=tile_array)
359
-
360
- return transformed_tile, tile_info
@@ -0,0 +1,88 @@
1
+ """In-memory prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.transformations import NormalizeModel
12
+ from .dataset_utils import reshape_array
13
+
14
+
15
class InMemoryPredDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration; its axes, image means and image standard
            deviations are read here.
        inputs : NDArray
            Input data.
        """
        self.pred_config = prediction_config
        self.input_array = inputs

        # Values taken from the prediction configuration
        config = self.pred_config
        self.axes = config.axes
        self.image_means = config.image_means
        self.image_stds = config.image_stds

        # Reshape data
        self.data = reshape_array(self.input_array, self.axes)

        # The only transform applied at prediction time is the normalization
        normalization = NormalizeModel(
            image_means=self.image_means, image_stds=self.image_stds
        )
        self.patch_transform = Compose(transform_list=[normalization])

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> NDArray:
        """
        Return the transformed patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        NDArray
            Transformed patch.
        """
        sample = self.data[index]
        # The transform returns a pair; only its first element is used here
        transformed_patch, _ = self.patch_transform(patch=sample)

        return transformed_patch