careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0

careamics/dataset/iterable_tiled_pred_dataset.py (new file)
@@ -0,0 +1,140 @@
+ """Iterable tiled prediction dataset used to load data file by file."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Callable, Generator
+
+ from numpy.typing import NDArray
+ from torch.utils.data import IterableDataset
+
+ from careamics.file_io.read import read_tiff
+ from careamics.transforms import Compose
+
+ from ..config import InferenceConfig
+ from ..config.tile_information import TileInformation
+ from ..config.transformations import NormalizeModel
+ from .dataset_utils import iterate_over_files
+ from .tiling import extract_tiles
+
+
+ class IterableTiledPredDataset(IterableDataset):
+     """Tiled prediction dataset.
+
+     Parameters
+     ----------
+     prediction_config : InferenceConfig
+         Inference configuration.
+     src_files : list of pathlib.Path
+         List of data files.
+     read_source_func : Callable, optional
+         Read source function for custom types, by default read_tiff.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+
+     Attributes
+     ----------
+     data_path : str or pathlib.Path
+         Path to the data, must be a directory.
+     axes : str
+         Description of axes in format STCZYX.
+     mean : float, optional
+         Expected mean of the dataset, by default None.
+     std : float, optional
+         Expected standard deviation of the dataset, by default None.
+     patch_transform : Callable, optional
+         Patch transform callable, by default None.
+     """
+
+     def __init__(
+         self,
+         prediction_config: InferenceConfig,
+         src_files: list[Path],
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         """Constructor.
+
+         Parameters
+         ----------
+         prediction_config : InferenceConfig
+             Inference configuration.
+         src_files : list of pathlib.Path
+             List of data files.
+         read_source_func : Callable, optional
+             Read source function for custom types, by default read_tiff.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+
+         Raises
+         ------
+         ValueError
+             If the inference configuration is missing tile size and overlap, or mean and std.
+         """
+         if (
+             prediction_config.tile_size is None
+             or prediction_config.tile_overlap is None
+         ):
+             raise ValueError(
+                 "Tile size and overlap must be provided for tiled prediction."
+             )
+
+         self.prediction_config = prediction_config
+         self.data_files = src_files
+         self.axes = prediction_config.axes
+         self.tile_size = prediction_config.tile_size
+         self.tile_overlap = prediction_config.tile_overlap
+         self.read_source_func = read_source_func
+
+         # check mean and std and create normalize transform
+         if (
+             self.prediction_config.image_means is None
+             or self.prediction_config.image_stds is None
+         ):
+             raise ValueError("Mean and std must be provided for prediction.")
+         else:
+             self.image_means = self.prediction_config.image_means
+             self.image_stds = self.prediction_config.image_stds
+
+         # instantiate normalize transform
+         self.patch_transform = Compose(
+             transform_list=[
+                 NormalizeModel(
+                     image_means=self.image_means,
+                     image_stds=self.image_stds,
+                 )
+             ],
+         )
+
+     def __iter__(
+         self,
+     ) -> Generator[tuple[NDArray, TileInformation], None, None]:
+         """
+         Iterate over data source and yield single tiles.
+
+         Yields
+         ------
+         tuple of NDArray and TileInformation
+             Single transformed tiles and their tile information.
+         """
+         assert (
+             self.image_means is not None and self.image_stds is not None
+         ), "Mean and std must be provided"
+
+         for sample, _ in iterate_over_files(
+             self.prediction_config,
+             self.data_files,
+             read_source_func=self.read_source_func,
+         ):
+             # generate tiles, return a generator of single tiles
+             patch_gen = extract_tiles(
+                 arr=sample,
+                 tile_size=self.tile_size,
+                 overlaps=self.tile_overlap,
+             )
+
+             # apply transform to tiles
+             for patch_array, tile_info in patch_gen:
+                 transformed_patch, _ = self.patch_transform(patch=patch_array)
+
+                 yield transformed_patch, tile_info
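
As a usage sketch (not part of the diff): the dataset can be driven directly. The InferenceConfig field names mirror the attributes accessed in __init__ above, but the exact constructor call and the data_type value are assumptions.

    # Hypothetical usage sketch of the new tiled prediction dataset
    from pathlib import Path

    from careamics.config import InferenceConfig
    from careamics.dataset.iterable_tiled_pred_dataset import IterableTiledPredDataset

    config = InferenceConfig(
        data_type="tiff",
        axes="YX",
        tile_size=[128, 128],
        tile_overlap=[32, 32],
        image_means=[0.5],
        image_stds=[0.2],
    )
    files = sorted(Path("to_predict").glob("*.tiff"))
    dataset = IterableTiledPredDataset(config, files)

    for tile, tile_info in dataset:  # tiles are normalized before being yielded
        ...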

careamics/dataset/patching/patching.py
@@ -1,37 +1,83 @@
  """Patching functions."""
 
+ from dataclasses import dataclass
  from pathlib import Path
- from typing import Callable, List, Tuple, Union
+ from typing import Callable, Union
 
  import numpy as np
+ from numpy.typing import NDArray
 
  from ...utils.logging import get_logger
  from ..dataset_utils import reshape_array
+ from ..dataset_utils.running_stats import compute_normalization_stats
  from .sequential_patching import extract_patches_sequential
 
  logger = get_logger(__name__)
 
 
+ @dataclass
+ class Stats:
+     """Dataclass to store statistics."""
+
+     means: Union[NDArray, tuple, list, None]
+     """Mean of the data across channels."""
+
+     stds: Union[NDArray, tuple, list, None]
+     """Standard deviation of the data across channels."""
+
+     def get_statistics(self) -> tuple[list[float], list[float]]:
+         """Return the means and standard deviations.
+
+         Returns
+         -------
+         tuple of two lists of floats
+             Means and standard deviations.
+         """
+         if self.means is None or self.stds is None:
+             return [], []
+
+         return list(self.means), list(self.stds)
+
+
+ @dataclass
+ class PatchedOutput:
+     """Dataclass to store patches and statistics."""
+
+     patches: Union[NDArray]
+     """Image patches."""
+
+     targets: Union[NDArray, None]
+     """Target patches."""
+
+     image_stats: Stats
+     """Statistics of the image patches."""
+
+     target_stats: Stats
+     """Statistics of the target patches."""
+
+
  # called by in memory dataset
  def prepare_patches_supervised(
-     train_files: List[Path],
-     target_files: List[Path],
+     train_files: list[Path],
+     target_files: list[Path],
      axes: str,
-     patch_size: Union[List[int], Tuple[int, ...]],
+     patch_size: Union[list[int], tuple[int, ...]],
      read_source_func: Callable,
- ) -> Tuple[np.ndarray, np.ndarray, float, float]:
+ ) -> PatchedOutput:
      """
      Iterate over data source and create an array of patches and corresponding targets.
 
+     The lists of Paths should be pre-sorted.
+
      Parameters
      ----------
-     train_files : List[Path]
+     train_files : list of pathlib.Path
          List of paths to training data.
-     target_files : List[Path]
+     target_files : list of pathlib.Path
          List of paths to target data.
      axes : str
          Axes of the data.
-     patch_size : Union[List[int], Tuple[int]]
+     patch_size : list or tuple of int
          Size of the patches.
      read_source_func : Callable
          Function to read the data.
@@ -41,9 +87,6 @@ def prepare_patches_supervised(
      np.ndarray
          Array of patches.
      """
-     train_files.sort()
-     target_files.sort()
-
      means, stds, num_samples = 0, 0, 0
      all_patches, all_targets = [], []
      for train_filename, target_filename in zip(train_files, target_files):
@@ -83,46 +126,47 @@ def prepare_patches_supervised(
              f"{target_files}."
          )
 
-     result_mean, result_std = means / num_samples, stds / num_samples
+     image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
+     target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
 
      patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
      target_array: np.ndarray = np.concatenate(all_targets, axis=0)
      logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
 
-     return (
+     return PatchedOutput(
          patch_array,
          target_array,
-         result_mean,
-         result_std,
+         Stats(image_means, image_stds),
+         Stats(target_means, target_stds),
      )
 
 
  # called by in_memory_dataset
  def prepare_patches_unsupervised(
-     train_files: List[Path],
+     train_files: list[Path],
      axes: str,
-     patch_size: Union[List[int], Tuple[int]],
+     patch_size: Union[list[int], tuple[int]],
      read_source_func: Callable,
- ) -> Tuple[np.ndarray, None, float, float]:
+ ) -> PatchedOutput:
      """Iterate over data source and create an array of patches.
 
      This method returns the mean and standard deviation of the image.
 
      Parameters
      ----------
-     train_files : List[Path]
+     train_files : list of pathlib.Path
          List of paths to training data.
      axes : str
          Axes of the data.
-     patch_size : Union[List[int], Tuple[int]]
+     patch_size : list or tuple of int
          Size of the patches.
      read_source_func : Callable
          Function to read the data.
 
      Returns
      -------
-     Tuple[np.ndarray, None, float, float]
-         Source and target patches, mean and standard deviation.
+     PatchedOutput
+         Dataclass holding patches and their statistics.
      """
      means, stds, num_samples = 0, 0, 0
      all_patches = []
@@ -149,21 +193,23 @@ def prepare_patches_unsupervised(
      if num_samples == 0:
          raise ValueError(f"No valid samples found in the input data: {train_files}.")
 
-     result_mean, result_std = means / num_samples, stds / num_samples
+     image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
 
      patch_array: np.ndarray = np.concatenate(all_patches)
      logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
 
-     return patch_array, _, result_mean, result_std  # TODO return object?
+     return PatchedOutput(
+         patch_array, None, Stats(image_means, image_stds), Stats((), ())
+     )
 
 
  # called on arrays by in memory dataset
  def prepare_patches_supervised_array(
-     data: np.ndarray,
+     data: NDArray,
      axes: str,
-     data_target: np.ndarray,
-     patch_size: Union[List[int], Tuple[int]],
- ) -> Tuple[np.ndarray, np.ndarray, float, float]:
+     data_target: NDArray,
+     patch_size: Union[list[int], tuple[int]],
+ ) -> PatchedOutput:
      """Iterate over data source and create an array of patches.
 
      This method expects an array of shape SC(Z)YX, where S and C can be singleton
@@ -173,28 +219,28 @@ def prepare_patches_supervised_array(
 
      Parameters
      ----------
-     data : np.ndarray
+     data : numpy.ndarray
          Input data array.
      axes : str
          Axes of the data.
-     data_target : np.ndarray
+     data_target : numpy.ndarray
          Target data array.
-     patch_size : Union[List[int], Tuple[int]]
+     patch_size : list or tuple of int
          Size of the patches.
 
      Returns
      -------
-     Tuple[np.ndarray, np.ndarray, float, float]
-         Source and target patches, mean and standard deviation.
+     PatchedOutput
+         Dataclass holding the source and target patches, with their statistics.
      """
-     # compute statistics
-     mean = data.mean()
-     std = data.std()
-
      # reshape array
      reshaped_sample = reshape_array(data, axes)
      reshaped_target = reshape_array(data_target, axes)
 
+     # compute statistics
+     image_means, image_stds = compute_normalization_stats(reshaped_sample)
+     target_means, target_stds = compute_normalization_stats(reshaped_target)
+
      # generate patches, return a generator
      patches, patch_targets = extract_patches_sequential(
          reshaped_sample, patch_size=patch_size, target=reshaped_target
@@ -205,20 +251,20 @@ def prepare_patches_supervised_array(
 
      logger.info(f"Extracted {patches.shape[0]} patches from input array.")
 
-     return (
+     return PatchedOutput(
          patches,
          patch_targets,
-         mean,
-         std,
+         Stats(image_means, image_stds),
+         Stats(target_means, target_stds),
      )
 
 
  # called by in memory dataset
  def prepare_patches_unsupervised_array(
-     data: np.ndarray,
+     data: NDArray,
      axes: str,
-     patch_size: Union[List[int], Tuple[int]],
- ) -> Tuple[np.ndarray, None, float, float]:
+     patch_size: Union[list[int], tuple[int]],
+ ) -> PatchedOutput:
      """
      Iterate over data source and create an array of patches.
 
@@ -229,26 +275,25 @@ def prepare_patches_unsupervised_array(
 
      Parameters
      ----------
-     data : np.ndarray
+     data : numpy.ndarray
          Input data array.
      axes : str
          Axes of the data.
-     patch_size : Union[List[int], Tuple[int]]
+     patch_size : list or tuple of int
          Size of the patches.
 
      Returns
      -------
-     Tuple[np.ndarray, None, float, float]
-         Source patches, mean and standard deviation.
+     PatchedOutput
+         Dataclass holding the patches and their statistics.
      """
-     # calculate mean and std
-     mean = data.mean()
-     std = data.std()
-
      # reshape array
      reshaped_sample = reshape_array(data, axes)
 
+     # calculate mean and std
+     means, stds = compute_normalization_stats(reshaped_sample)
+
      # generate patches, return a generator
      patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
 
-     return patches, _, mean, std  # TODO inelegant, replace by dataclass?
+     return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
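
Callers that unpacked the rc6 tuples now consume a PatchedOutput instead. A minimal sketch of the new calling convention (not part of the diff; array shape and values are illustrative):

    import numpy as np

    from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

    output = prepare_patches_unsupervised_array(
        data=np.random.rand(1, 64, 64),  # illustrative array with axes "SYX"
        axes="SYX",
        patch_size=[32, 32],
    )
    patches = output.patches  # rc6 equivalent: patches, _, mean, std = ...
    means, stds = output.image_stats.get_statistics()  # per-channel lists of floats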

careamics/dataset/patching/random_patching.py
@@ -13,6 +13,7 @@ def extract_patches_random(
      arr: np.ndarray,
      patch_size: Union[List[int], Tuple[int, ...]],
      target: Optional[np.ndarray] = None,
+     seed: Optional[int] = None,
  ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
      """
      Generate patches from an array in a random manner.
@@ -34,12 +35,16 @@
          Patch sizes in each dimension.
      target : Optional[np.ndarray], optional
          Target array, by default None.
+     seed : Optional[int], optional
+         Random seed, by default None.
 
      Yields
      ------
      Generator[np.ndarray, None, None]
          Generator of patches.
      """
+     rng = np.random.default_rng(seed=seed)
+
      is_3d_patch = len(patch_size) == 3
 
      # patches sanity check
@@ -48,9 +53,6 @@
      # Update patch size to encompass S and C dimensions
      patch_size = [1, arr.shape[1], *patch_size]
 
-     # random generator
-     rng = np.random.default_rng()
-
      # iterate over the number of samples (S or T)
      for sample_idx in range(arr.shape[0]):
          # get sample array
@@ -113,6 +115,7 @@ def extract_patches_random_from_chunks(
      patch_size: Union[List[int], Tuple[int, ...]],
      chunk_size: Union[List[int], Tuple[int, ...]],
      chunk_limit: Optional[int] = None,
+     seed: Optional[int] = None,
  ) -> Generator[np.ndarray, None, None]:
      """
      Generate patches from an array in a random manner.
@@ -130,6 +133,8 @@
          Chunk sizes to load from the.
      chunk_limit : Optional[int], optional
          Number of chunks to load, by default None.
+     seed : Optional[int], optional
+         Random seed, by default None.
 
      Yields
      ------
@@ -141,7 +146,7 @@
      # Patches sanity check
      validate_patch_dimensions(arr, patch_size, is_3d_patch)
 
-     rng = np.random.default_rng()
+     rng = np.random.default_rng(seed=seed)
      num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
 
      # Iterate over num chunks in the array
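
With the new seed argument, random patch extraction becomes reproducible. A sketch (not part of the diff; shapes are illustrative):

    import numpy as np

    from careamics.dataset.patching.random_patching import extract_patches_random

    arr = np.random.rand(1, 1, 64, 64)  # SCYX
    first = [p for p, _ in extract_patches_random(arr, patch_size=[32, 32], seed=42)]
    second = [p for p, _ in extract_patches_random(arr, patch_size=[32, 32], seed=42)]
    assert all(np.array_equal(a, b) for a, b in zip(first, second))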

careamics/dataset/patching/validate_patch_dimension.py
@@ -45,18 +45,20 @@
      if len(patch_size) != len(arr.shape[2:]):
          raise ValueError(
              f"There must be a patch size for each spatial dimensions "
-             f"(got {patch_size} patches for dims {arr.shape})."
+             f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
          )
 
      # Sanity checks on patch sizes versus array dimension
      if is_3d_patch and patch_size[0] > arr.shape[-3]:
          raise ValueError(
              f"Z patch size is inconsistent with image shape "
-             f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
+             f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
+             f"order."
          )
 
      if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
          raise ValueError(
              f"At least one of YX patch dimensions is larger than the corresponding "
-             f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
+             f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
+             f"Check the axes order."
          )
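
The amended messages all point to the most likely cause, a mis-ordered axes string. A sketch of the failure they guard against (not part of the diff; assumes the function is called with an SC(Z)YX array, as in random_patching.py above):

    import numpy as np

    from careamics.dataset.patching.validate_patch_dimension import (
        validate_patch_dimensions,
    )

    arr = np.zeros((1, 1, 64, 64))  # SCYX: 2D data with singleton S and C
    try:
        validate_patch_dimensions(arr, [8, 32, 32], True)  # 3D patch on 2D data
    except ValueError as err:
        print(err)  # message now ends with "Check the axes order."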

careamics/dataset/tiling/__init__.py (new file)
@@ -0,0 +1,10 @@
+ """Tiling functions."""
+
+ __all__ = [
+     "stitch_prediction",
+     "extract_tiles",
+     "collate_tiles",
+ ]
+
+ from .collate_tiles import collate_tiles
+ from .tiled_patching import extract_tiles

careamics/dataset/tiling/collate_tiles.py (new file)
@@ -0,0 +1,33 @@
+ """Collate function for tiling."""
+
+ from typing import Any, List, Tuple
+
+ import numpy as np
+ from torch.utils.data.dataloader import default_collate
+
+ from careamics.config.tile_information import TileInformation
+
+
+ def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+     """
+     Collate tiles received from CAREamics prediction dataloader.
+
+     CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
+     case of non-tiled data, this function will return the arrays. In case of tiled
+     data, it will return the arrays, the last tile flag, the overlap crop coordinates
+     and the stitch coordinates.
+
+     Parameters
+     ----------
+     batch : List[Tuple[np.ndarray, TileInformation], ...]
+         Batch of tiles.
+
+     Returns
+     -------
+     Any
+         Collated batch.
+     """
+     new_batch = [tile for tile, _ in batch]
+     tiles_batch = [tile_info for _, tile_info in batch]
+
+     return default_collate(new_batch), tiles_batch
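
collate_tiles is intended as a DataLoader collate_fn, so that TileInformation objects survive batching. A sketch (not part of the diff; `dataset` refers to the IterableTiledPredDataset sketched earlier):

    from torch.utils.data import DataLoader

    from careamics.dataset.tiling import collate_tiles

    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_tiles)
    for batch, tile_infos in loader:
        # batch is a collated tensor of tiles, tile_infos a list of TileInformation
        ...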

careamics/dataset/tiling/tiled_patching.py (moved from careamics/dataset/patching/)
@@ -84,15 +84,15 @@ def extract_tiles(
      tile_size: Union[List[int], Tuple[int, ...]],
      overlaps: Union[List[int], Tuple[int, ...]],
  ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-     """
-     Generate tiles from the input array with specified overlap.
+     """Generate tiles from the input array with specified overlap.
 
      The tiles cover the whole array. The method returns a generator that yields
      tuples of array and tile information, the latter includes whether
      the tile is the last one, the coordinates of the overlap crop, and the coordinates
      of the stitched tile.
 
-     The array has shape C(Z)YX, where C can be a singleton.
+     Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
+     where C can be a singleton.
 
      Parameters
      ----------
@@ -155,10 +155,10 @@
          # create tile information
          tile_info = TileInformation(
              array_shape=sample.squeeze().shape,
-             tiled=True,
              last_tile=last_tile,
              overlap_crop_coords=overlap_crop_coords,
              stitch_coords=stitch_coords,
+             sample_id=sample_idx,
          )
 
          yield tile, tile_info
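
A sketch of the updated contract, SC(Z)YX in and C(Z)YX tiles out, plus the new sample_id field (not part of the diff; shapes are illustrative):

    import numpy as np

    from careamics.dataset.tiling import extract_tiles

    arr = np.random.rand(2, 1, 64, 64)  # SCYX: two samples, one channel
    for tile, info in extract_tiles(arr=arr, tile_size=[32, 32], overlaps=[16, 16]):
        assert tile.shape == (1, 32, 32)  # tiles come back as C(Z)YX
        print(info.sample_id, info.last_tile)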

careamics/file_io/__init__.py (new file)
@@ -0,0 +1,7 @@
+ """Functions relating to reading and writing image files."""
+
+ __all__ = ["read", "write", "get_read_func", "get_write_func"]
+
+ from . import read, write
+ from .read import get_read_func
+ from .write import get_write_func

careamics/file_io/read/__init__.py (new file)
@@ -0,0 +1,11 @@
+ """Functions relating to reading image files of different formats."""
+
+ __all__ = [
+     "get_read_func",
+     "read_tiff",
+     "read_zarr",
+ ]
+
+ from .get_func import get_read_func
+ from .tiff import read_tiff
+ from .zarr import read_zarr

careamics/file_io/read/get_func.py (new file)
@@ -0,0 +1,56 @@
+ """Module to get read functions."""
+
+ from pathlib import Path
+ from typing import Callable, Dict, Protocol, Union
+
+ from numpy.typing import NDArray
+
+ from careamics.config.support import SupportedData
+
+ from .tiff import read_tiff
+
+
+ # This is very strict, function signature has to match including arg names
+ # See WriteFunc notes
+ class ReadFunc(Protocol):
+     """Protocol for type hinting read functions."""
+
+     def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
+         """
+         Type hinted callables must match this function signature (not including self).
+
+         Parameters
+         ----------
+         file_path : pathlib.Path
+             Path to file.
+         *args
+             Other positional arguments.
+         **kwargs
+             Other keyword arguments.
+         """
+
+
+ READ_FUNCS: Dict[SupportedData, ReadFunc] = {
+     SupportedData.TIFF: read_tiff,
+ }
+
+
+ def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
+     """
+     Get the read function for the data type.
+
+     Parameters
+     ----------
+     data_type : SupportedData
+         Data type.
+
+     Returns
+     -------
+     callable
+         Read function.
+     """
+     if data_type in READ_FUNCS:
+         data_type = SupportedData(data_type)  # mypy complaining about dict key type
+         return READ_FUNCS[data_type]
+     else:
+         raise NotImplementedError(f"Data type '{data_type}' is not supported.")
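
get_read_func dispatches through the READ_FUNCS mapping, where only TIFF is registered in this release. A usage sketch (not part of the diff; assumes SupportedData is a str-valued enum so the string "tiff" matches SupportedData.TIFF):

    from pathlib import Path

    from careamics.file_io import get_read_func

    read_func = get_read_func("tiff")  # resolves to read_tiff via READ_FUNCS
    image = read_func(file_path=Path("image.tiff"))  # returns an NDArray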