careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/iterable_dataset.py
@@ -0,0 +1,416 @@
+ from __future__ import annotations
+
+ import copy
+ from pathlib import Path
+ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+
+ import numpy as np
+ from torch.utils.data import IterableDataset, get_worker_info
+
+ from ..config import DataModel, InferenceModel
+ from ..config.tile_information import TileInformation
+ from ..utils.logging import get_logger
+ from .dataset_utils import read_tiff, reshape_array
+ from .patching import (
+     get_patch_transform,
+ )
+ from .patching.random_patching import extract_patches_random
+ from .patching.tiled_patching import extract_tiles
+
+ logger = get_logger(__name__)
+
+
+ class PathIterableDataset(IterableDataset):
+     """
+     Dataset allowing extracting patches without loading whole data into memory.
+
+     Parameters
+     ----------
+     data_config : Union[DataModel, InferenceModel]
+         Data configuration.
+     src_files : List[Path]
+         List of source files.
+     target_files : Optional[List[Path]], optional
+         List of target files, by default None.
+     read_source_func : Callable, optional
+         Function used to read the source files, by default read_tiff.
+     """
+
+     def __init__(
+         self,
+         data_config: Union[DataModel, InferenceModel],
+         src_files: List[Path],
+         target_files: Optional[List[Path]] = None,
+         read_source_func: Callable = read_tiff,
+     ) -> None:
+         self.data_config = data_config
+         self.data_files = src_files
+         self.target_files = target_files
+         self.read_source_func = read_source_func
+
+         # compute mean and std over the dataset, if not provided
+         if data_config.mean is None or data_config.std is None:
+             self.mean, self.std = self._calculate_mean_and_std()
+
+             # check for a transform list via hasattr, since isinstance does not
+             # discriminate reliably between DataModel and InferenceModel here
+             if hasattr(data_config, "has_transform_list"):
+                 # update mean and std in the configuration; the object is
+                 # mutable, so the update is also recorded in the CAREamist
+                 data_config.set_mean_and_std(self.mean, self.std)
+         else:
+             self.mean = data_config.mean
+             self.std = data_config.std
+
+         # get transforms
+         self.patch_transform = get_patch_transform(
+             patch_transforms=data_config.transforms,
+             with_target=target_files is not None,
+         )
+
+     def _calculate_mean_and_std(self) -> Tuple[float, float]:
+         """
+         Calculate mean and std of the dataset.
+
+         Returns
+         -------
+         Tuple[float, float]
+             Tuple containing mean and standard deviation.
+         """
+         means, stds = 0, 0
+         num_samples = 0
+
+         for sample, _ in self._iterate_over_files():
+             means += sample.mean()
+             stds += sample.std()
+             num_samples += 1
+
+         if num_samples == 0:
+             raise ValueError("No samples found in the dataset.")
+
+         result_mean = means / num_samples
+         result_std = stds / num_samples
+
+         logger.info(f"Calculated mean and std for {num_samples} images")
+         logger.info(f"Mean: {result_mean}, std: {result_std}")
+         return result_mean, result_std
+
+     def _iterate_over_files(
+         self,
+     ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+         """
+         Iterate over the data source and yield whole images.
+
+         Yields
+         ------
+         Tuple[np.ndarray, Optional[np.ndarray]]
+             Image and its target, if available.
+         """
+         # When num_workers > 0, each worker process holds a different copy of
+         # the dataset object. Each copy is configured independently so that
+         # the workers do not return duplicate data.
+         worker_info = get_worker_info()
+         worker_id = worker_info.id if worker_info is not None else 0
+         num_workers = worker_info.num_workers if worker_info is not None else 1
+
+         # iterate over the files
+         for i, filename in enumerate(self.data_files):
+             # retrieve file corresponding to the worker id
+             if i % num_workers == worker_id:
+                 try:
+                     # read data
+                     sample = self.read_source_func(filename, self.data_config.axes)
+
+                     # read target, if available
+                     if self.target_files is not None:
+                         if filename.name != self.target_files[i].name:
+                             raise ValueError(
+                                 f"File {filename} does not match target file "
+                                 f"{self.target_files[i]}. Have you passed sorted "
+                                 f"lists of files?"
+                             )
+
+                         # read target
+                         target = self.read_source_func(
+                             self.target_files[i], self.data_config.axes
+                         )
+
+                         yield sample, target
+                     else:
+                         yield sample, None
+
+                 except Exception as e:
+                     logger.error(f"Error reading file {filename}: {e}")
+
+     def __iter__(
+         self,
+     ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]:
+         """
+         Iterate over the data source and yield single patches.
+
+         Yields
+         ------
+         Tuple[np.ndarray, ...]
+             A (patch, target) tuple for supervised training, or a
+             (masked patch, original patch, mask) tuple for Noise2Void.
+         """
+         assert (
+             self.mean is not None and self.std is not None
+         ), "Mean and std must be provided"
+
+         # iterate over files
+         for sample_input, sample_target in self._iterate_over_files():
+             reshaped_sample = reshape_array(sample_input, self.data_config.axes)
+             reshaped_target = (
+                 None
+                 if sample_target is None
+                 else reshape_array(sample_target, self.data_config.axes)
+             )
+
+             patches = extract_patches_random(
+                 arr=reshaped_sample,
+                 patch_size=self.data_config.patch_size,
+                 target=reshaped_target,
+             )
+
+             # iterate over patches
+             # patches are tuples of (patch, target) if a target is available,
+             # or (patch, None) otherwise; patches have dimensions (C)ZYX
+             for patch_data in patches:
+                 # if there is a target
+                 if self.target_files is not None:
+                     # Albumentations expects the channel dimension to be last
+                     c_patch = np.moveaxis(patch_data[0], 0, -1)
+                     c_target = np.moveaxis(patch_data[1], 0, -1)
+
+                     # apply the transform to the patch and the target
+                     transformed = self.patch_transform(
+                         image=c_patch,
+                         target=c_target,
+                     )
+
+                     # move the axes back to the original position
+                     c_patch = np.moveaxis(transformed["image"], -1, 0)
+                     c_target = np.moveaxis(transformed["target"], -1, 0)
+
+                     yield (c_patch, c_target)
+                 elif self.data_config.has_n2v_manipulate():
+                     # Albumentations expects the channel dimension to be last;
+                     # take the first element since patch_data can include a target
+                     patch = np.moveaxis(patch_data[0], 0, -1)
+
+                     # apply transform
+                     transformed = self.patch_transform(image=patch)
+
+                     # retrieve the output of ManipulateN2V
+                     results = transformed["image"]
+                     masked_patch: np.ndarray = results[0]
+                     original_patch: np.ndarray = results[1]
+                     mask: np.ndarray = results[2]
+
+                     # move C axis back
+                     masked_patch = np.moveaxis(masked_patch, -1, 0)
+                     original_patch = np.moveaxis(original_patch, -1, 0)
+                     mask = np.moveaxis(mask, -1, 0)
+
+                     yield (masked_patch, original_patch, mask)
+                 else:
+                     raise ValueError(
+                         "Something went wrong! No target file (no supervised "
+                         "training) and no N2V transform (no N2V training either)."
+                     )
+
+     def get_number_of_files(self) -> int:
+         """
+         Return the number of files in the dataset.
+
+         Returns
+         -------
+         int
+             Number of files in the dataset.
+         """
+         return len(self.data_files)
+
+     def split_dataset(
+         self,
+         percentage: float = 0.1,
+         minimum_number: int = 5,
+     ) -> PathIterableDataset:
+         """Split the dataset in two.
+
+         Splits the dataset using a percentage of the data (files) to extract,
+         or the minimum number if the percentage yields fewer files than the
+         minimum.
+
+         Parameters
+         ----------
+         percentage : float, optional
+             Percentage of files to extract, by default 0.1.
+         minimum_number : int, optional
+             Minimum number of files to extract, by default 5.
+
+         Returns
+         -------
+         PathIterableDataset
+             Dataset containing the extracted files.
+
+         Raises
+         ------
+         ValueError
+             If the percentage is smaller than 0 or larger than 1.
+         ValueError
+             If the minimum number is smaller than 1 or larger than the number of files.
+         """
+         if percentage < 0 or percentage > 1:
+             raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")
+
+         if minimum_number < 1 or minimum_number > self.get_number_of_files():
+             raise ValueError(
+                 f"Minimum number of files must be between 1 and "
+                 f"{self.get_number_of_files()} (number of files), got "
+                 f"{minimum_number}."
+             )
+
+         # compute number of files to extract
+         total_files = self.get_number_of_files()
+         n_files = max(round(percentage * total_files), minimum_number)
+
+         # get random indices
+         indices = np.random.choice(total_files, n_files, replace=False)
+
+         # extract files
+         val_files = [self.data_files[i] for i in indices]
+
+         # remove the extracted files from self.data_files
+         data_files = []
+         for i, file in enumerate(self.data_files):
+             if i not in indices:
+                 data_files.append(file)
+         self.data_files = data_files
+
+         # same for targets
+         if self.target_files is not None:
+             val_target_files = [self.target_files[i] for i in indices]
+
+             data_target_files = []
+             for i, file in enumerate(self.target_files):
+                 if i not in indices:
+                     data_target_files.append(file)
+             self.target_files = data_target_files
+
+         # clone the dataset
+         dataset = copy.deepcopy(self)
+
+         # reassign files
+         dataset.data_files = val_files
+
+         # reassign targets
+         if self.target_files is not None:
+             dataset.target_files = val_target_files
+
+         return dataset
+
+
+ class IterablePredictionDataset(PathIterableDataset):
+     """
+     Prediction dataset allowing extracting tiles without loading whole data
+     into memory.
+
+     Parameters
+     ----------
+     prediction_config : InferenceModel
+         Inference configuration.
+     src_files : List[Path]
+         List of source files.
+     read_source_func : Callable, optional
+         Function used to read the source files, by default read_tiff.
+     **kwargs : Any
+         Unused keyword arguments, kept for compatibility.
+     """
+
+     def __init__(
+         self,
+         prediction_config: InferenceModel,
+         src_files: List[Path],
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             data_config=prediction_config,
+             src_files=src_files,
+             read_source_func=read_source_func,
+         )
+
+         self.prediction_config = prediction_config
+         self.axes = prediction_config.axes
+         self.tile_size = self.prediction_config.tile_size
+         self.tile_overlap = self.prediction_config.tile_overlap
+         self.read_source_func = read_source_func
+
+         # tile only if both tile size and overlaps are provided
+         self.tile = self.tile_size is not None and self.tile_overlap is not None
+
+         # get the transforms (no target during prediction)
+         self.patch_transform = get_patch_transform(
+             patch_transforms=prediction_config.transforms,
+             with_target=False,
+         )
+
+     def __iter__(
+         self,
+     ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
+         """
+         Iterate over the data source and yield single tiles.
+
+         Yields
+         ------
+         Tuple[np.ndarray, TileInformation]
+             Single tile and its tiling information.
+         """
+         assert (
+             self.mean is not None and self.std is not None
+         ), "Mean and std must be provided"
+
+         for sample, _ in self._iterate_over_files():
+             # reshape array
+             reshaped_sample = reshape_array(sample, self.axes)
+
+             if self.tile:
+                 # generate tiles, return a generator
+                 patch_gen = extract_tiles(
+                     arr=reshaped_sample,
+                     tile_size=self.tile_size,
+                     overlaps=self.tile_overlap,
+                 )
+             else:
+                 # just wrap the sample in a generator with default tiling info
+                 array_shape = reshaped_sample.squeeze().shape
+                 patch_gen = (
+                     (reshaped_sample, TileInformation(array_shape=array_shape))
+                     for _ in range(1)
+                 )
+
+             # apply transform to tiles
+             for patch_array, tile_info in patch_gen:
+                 # Albumentations expects the channel dimension to be last
+                 patch = np.moveaxis(patch_array, 0, -1)
+                 transformed_patch = self.patch_transform(image=patch)
+                 transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
+
+                 yield transformed_patch, tile_info
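
For orientation, here is a minimal usage sketch of the new datasets (not part of the diff). The DataModel field names are assumptions inferred from the attributes the code above accesses (axes, patch_size, transforms, mean, std); the actual required fields may differ, and the configured transforms are assumed to be compatible with supervised targets:

from pathlib import Path

from torch.utils.data import DataLoader

from careamics.config import DataModel
from careamics.dataset.iterable_dataset import PathIterableDataset

# hypothetical configuration; field names inferred from the diff above
config = DataModel(data_type="tiff", axes="YX", patch_size=[64, 64])

# sorted source and target files, so that names match pairwise
train_files = sorted(Path("data/train").glob("*.tiff"))
target_files = sorted(Path("data/target").glob("*.tiff"))

dataset = PathIterableDataset(
    data_config=config,
    src_files=train_files,
    target_files=target_files,
)

# split off a validation set: 10% of the files, at least 5
val_dataset = dataset.split_dataset(percentage=0.1, minimum_number=5)

# each worker reads a disjoint subset of files (i % num_workers == worker_id)
loader = DataLoader(dataset, batch_size=16, num_workers=4)

Since the dataset shards files (not patches) across workers, the number of files should substantially exceed num_workers for an even load.
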
careamics/dataset/patching/__init__.py
@@ -0,0 +1,8 @@
+ """Patching and tiling functions."""
+
+
+ __all__ = [
+     "get_patch_transform",
+ ]
+
+ from .patch_transform import get_patch_transform
careamics/dataset/patching/patch_transform.py
@@ -0,0 +1,44 @@
+ from typing import List, Union
+
+ import albumentations as Aug
+
+ from careamics.config.data_model import TRANSFORMS_UNION
+ from careamics.transforms import get_all_transforms
+
+
+ # TODO add some explanations on how the additional_targets is used
+ def get_patch_transform(
+     patch_transforms: Union[List[TRANSFORMS_UNION], Aug.Compose],
+     with_target: bool,
+     normalize_mask: bool = True,
+ ) -> Aug.Compose:
+     """Return an Albumentations Compose assembling the configured patch transforms."""
+     # if we were passed a Compose, we just return it
+     if isinstance(patch_transforms, Aug.Compose):
+         return patch_transforms
+
+     # an empty list of transforms is a NoOp
+     elif len(patch_transforms) == 0:
+         return Aug.Compose(
+             [Aug.NoOp()],
+             additional_targets={},  # TODO this part needs to be checked (wrt segmentation)
+         )
+
+     # else we have a list of transforms
+     else:
+         # retrieve all transforms
+         all_transforms = get_all_transforms()
+
+         # instantiate all transforms
+         transforms = [
+             all_transforms[transform.name](**transform.model_dump())
+             for transform in patch_transforms
+         ]
+
+         return Aug.Compose(
+             transforms,
+             # apply the same image augmentations to "target"
+             additional_targets={"target": "image"}
+             if (with_target and normalize_mask)  # TODO check this
+             else {},
+         )
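
A short sketch of the three branches in get_patch_transform (not part of the diff): a ready-made Compose passes through, an empty list becomes a NoOp pipeline, and additional_targets={"target": "image"} is the Albumentations mechanism that applies the same spatial transforms to a target array:

import albumentations as Aug
import numpy as np

from careamics.dataset.patching import get_patch_transform

patch = np.random.rand(64, 64, 1).astype(np.float32)  # channel-last, as expected

# 1) a pre-built Compose is returned unchanged
compose = Aug.Compose([Aug.HorizontalFlip(p=1.0)])
assert get_patch_transform(compose, with_target=False) is compose

# 2) an empty transform list yields a NoOp pipeline
noop = get_patch_transform([], with_target=False)
assert np.array_equal(noop(image=patch)["image"], patch)

# 3) what additional_targets={"target": "image"} sets up: "target" is treated
# as a second image, so image and target receive identical transforms
paired = Aug.Compose(
    [Aug.HorizontalFlip(p=1.0)],
    additional_targets={"target": "image"},
)
out = paired(image=patch, target=patch.copy())
assert np.array_equal(out["image"], out["target"])
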
careamics/dataset/patching/patching.py
@@ -0,0 +1,212 @@
+ """
+ Tiling submodule.
+
+ These functions are used to tile images into patches or tiles.
+ """
+ from pathlib import Path
+ from typing import Callable, List, Tuple, Union
+
+ import numpy as np
+
+ from ...utils.logging import get_logger
+ from ..dataset_utils import reshape_array
+ from .sequential_patching import extract_patches_sequential
+
+ logger = get_logger(__name__)
+
+
+ # called by the in-memory dataset
+ def prepare_patches_supervised(
+     train_files: List[Path],
+     target_files: List[Path],
+     axes: str,
+     patch_size: Union[List[int], Tuple[int]],
+     read_source_func: Callable,
+ ) -> Tuple[np.ndarray, np.ndarray, float, float]:
+     """
+     Iterate over the data source and create arrays of patches and corresponding targets.
+
+     Returns
+     -------
+     Tuple[np.ndarray, np.ndarray, float, float]
+         Array of patches, array of targets, mean and standard deviation.
+     """
+     train_files.sort()
+     target_files.sort()
+
+     means, stds, num_samples = 0, 0, 0
+     all_patches, all_targets = [], []
+     for train_filename, target_filename in zip(train_files, target_files):
+         try:
+             sample: np.ndarray = read_source_func(train_filename, axes)
+             target: np.ndarray = read_source_func(target_filename, axes)
+             means += sample.mean()
+             stds += sample.std()
+             num_samples += 1
+
+             # reshape arrays
+             sample = reshape_array(sample, axes)
+             target = reshape_array(target, axes)
+
+             # generate patches
+             patches, targets = extract_patches_sequential(
+                 sample, patch_size=patch_size, target=target
+             )
+
+             # add the patches to the list
+             all_patches.append(patches)
+
+             # ensure targets are not None (type checking)
+             if targets is not None:
+                 all_targets.append(targets)
+             else:
+                 raise ValueError(f"No target found for {target_filename}.")
+
+         except Exception as e:
+             # emit an error and continue
+             logger.error(f"Failed to read {train_filename} or {target_filename}: {e}")
+
+     # raise error if no valid samples found
+     if num_samples == 0 or len(all_patches) == 0:
+         raise ValueError(
+             f"No valid samples found in the input data: {train_files} and "
+             f"{target_files}."
+         )
+
+     result_mean, result_std = means / num_samples, stds / num_samples
+
+     patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
+     target_array: np.ndarray = np.concatenate(all_targets, axis=0)
+     logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
+
+     return (
+         patch_array,
+         target_array,
+         result_mean,
+         result_std,
+     )
+
+
+ # called by the in-memory dataset
+ def prepare_patches_unsupervised(
+     train_files: List[Path],
+     axes: str,
+     patch_size: Union[List[int], Tuple[int]],
+     read_source_func: Callable,
+ ) -> Tuple[np.ndarray, None, float, float]:
+     """
+     Iterate over the data source and create an array of patches.
+
+     Returns
+     -------
+     Tuple[np.ndarray, None, float, float]
+         Array of patches, None (no targets), mean and standard deviation.
+     """
+     means, stds, num_samples = 0, 0, 0
+     all_patches = []
+     for filename in train_files:
+         try:
+             sample: np.ndarray = read_source_func(filename, axes)
+             means += sample.mean()
+             stds += sample.std()
+             num_samples += 1
+
+             # reshape array
+             sample = reshape_array(sample, axes)
+
+             # generate patches
+             patches, _ = extract_patches_sequential(sample, patch_size=patch_size)
+
+             # add the patches to the list
+             all_patches.append(patches)
+         except Exception as e:
+             # emit an error and continue
+             logger.error(f"Failed to read {filename}: {e}")
+
+     # raise error if no valid samples found
+     if num_samples == 0:
+         raise ValueError(f"No valid samples found in the input data: {train_files}.")
+
+     result_mean, result_std = means / num_samples, stds / num_samples
+
+     patch_array: np.ndarray = np.concatenate(all_patches)
+     logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
+
+     return patch_array, None, result_mean, result_std  # TODO return object?
+
+
+ # called on arrays by the in-memory dataset
+ def prepare_patches_supervised_array(
+     data: np.ndarray,
+     axes: str,
+     data_target: np.ndarray,
+     patch_size: Union[List[int], Tuple[int]],
+ ) -> Tuple[np.ndarray, np.ndarray, float, float]:
+     """Iterate over the data source and create arrays of patches and targets.
+
+     This function expects an array of shape SC(Z)YX, where S and C can be
+     singleton dimensions.
+
+     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
+
+     Returns
+     -------
+     Tuple[np.ndarray, np.ndarray, float, float]
+         Array of patches, array of targets, mean and standard deviation.
+     """
+     # compute statistics
+     mean = data.mean()
+     std = data.std()
+
+     # reshape arrays
+     reshaped_sample = reshape_array(data, axes)
+     reshaped_target = reshape_array(data_target, axes)
+
+     # generate patches
+     patches, patch_targets = extract_patches_sequential(
+         reshaped_sample, patch_size=patch_size, target=reshaped_target
+     )
+
+     if patch_targets is None:
+         raise ValueError("No target extracted.")
+
+     logger.info(f"Extracted {patches.shape[0]} patches from input array.")
+
+     return (
+         patches,
+         patch_targets,
+         mean,
+         std,
+     )
+
+
+ # called by the in-memory dataset
+ def prepare_patches_unsupervised_array(
+     data: np.ndarray,
+     axes: str,
+     patch_size: Union[List[int], Tuple[int]],
+ ) -> Tuple[np.ndarray, None, float, float]:
+     """
+     Iterate over the data source and create an array of patches.
+
+     This function expects an array of shape SC(Z)YX, where S and C can be
+     singleton dimensions.
+
+     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
+
+     Returns
+     -------
+     Tuple[np.ndarray, None, float, float]
+         Array of patches, None (no targets), mean and standard deviation.
+     """
+     # calculate mean and std
+     mean = data.mean()
+     std = data.std()
+
+     # reshape array
+     reshaped_sample = reshape_array(data, axes)
+
+     # generate patches
+     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
+
+     return patches, None, mean, std  # TODO inelegant, replace by dataclass?
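
To illustrate the array entry point (not part of the diff), a sketch of prepare_patches_unsupervised_array on a single 2D image; the expected patch count and layout follow the SC(Z)YX convention stated in the docstrings above, so the printed shape is an assumption rather than a guarantee:

import numpy as np

from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

image = np.random.rand(256, 256).astype(np.float32)

patches, _, mean, std = prepare_patches_unsupervised_array(
    data=image,
    axes="YX",
    patch_size=[64, 64],
)

# sequential patching tiles the image: 256 / 64 = 4 patches per axis,
# i.e. 16 patches stacked along the first (sample) dimension
print(patches.shape)  # expected: (16, 1, 64, 64)
print(mean, std)      # statistics computed on the raw input array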