careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/dataset_utils/read_tiff.py
@@ -0,0 +1,61 @@
+ import logging
+ from fnmatch import fnmatch
+ from pathlib import Path
+
+ import numpy as np
+ import tifffile
+
+ from careamics.config.support import SupportedData
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
+     """
+     Read a tiff file and return a numpy array.
+
+     Parameters
+     ----------
+     file_path : Path
+         Path to a file.
+     *args : list
+         Unused positional arguments, kept for compatibility with other readers.
+     **kwargs : dict
+         Unused keyword arguments, kept for compatibility with other readers.
+
+     Returns
+     -------
+     np.ndarray
+         Resulting array.
+
+     Raises
+     ------
+     ValueError
+         If the file failed to open.
+     OSError
+         If the file failed to open.
+     ValueError
+         If the file is not a valid tiff.
+     ValueError
+         If the data dimensions are incorrect.
+     """
+     if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)):
+         try:
+             array = tifffile.imread(file_path)
+         except (ValueError, OSError) as e:
+             logging.exception(f"Exception in file {file_path}: {e}, skipping it.")
+             raise e
+     else:
+         raise ValueError(f"File {file_path} is not a valid tiff.")
+
+     # check dimensions
+     # TODO or should this really be done here? probably in the LightningDataModule
+     # TODO this should also be centralized somewhere else (validate_dimensions)
+     if len(array.shape) < 2 or len(array.shape) > 6:
+         raise ValueError(
+             f"Incorrect data dimensions. Must be between 2 and 6 (got {array.shape} "
+             f"for file {file_path})."
+         )
+
+     return array
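
For orientation, a minimal usage sketch of the reader added above. The import path is inferred from the file list (careamics/dataset/dataset_utils/read_tiff.py), and the demo file is hypothetical:

```python
from pathlib import Path

import numpy as np
import tifffile

from careamics.dataset.dataset_utils.read_tiff import read_tiff

# write a small demo file, then read it back through the new reader
tifffile.imwrite("demo.tiff", np.zeros((32, 32), dtype=np.float32))
image = read_tiff(Path("demo.tiff"))
print(image.shape)  # (32, 32)
```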
careamics/dataset/dataset_utils/read_utils.py
@@ -0,0 +1,25 @@
+ from typing import Callable, Union
+
+ from careamics.config.support import SupportedData
+
+ from .read_tiff import read_tiff
+
+
+ def get_read_func(data_type: Union[SupportedData, str]) -> Callable:
+     """
+     Get the read function for the data type.
+
+     Parameters
+     ----------
+     data_type : SupportedData
+         Data type.
+
+     Returns
+     -------
+     Callable
+         Read function.
+     """
+     if data_type == SupportedData.TIFF:
+         return read_tiff
+     else:
+         raise NotImplementedError(f"Data type {data_type} is not supported.")
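
A similar sketch for the dispatch function above (same import-path assumption). TIFF is currently the only type it resolves; anything else raises NotImplementedError:

```python
from careamics.config.support import SupportedData
from careamics.dataset.dataset_utils.read_utils import get_read_func

# resolves to read_tiff for tiff data
read_func = get_read_func(SupportedData.TIFF)
assert read_func.__name__ == "read_tiff"
```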
careamics/dataset/dataset_utils/read_zarr.py
@@ -0,0 +1,56 @@
+ from typing import Union
+
+ from zarr import Group, core, hierarchy, storage
+
+
+ def read_zarr(
+     zarr_source: Group, axes: str
+ ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
+     """Read a zarr source and return a pointer to its data.
+
+     Parameters
+     ----------
+     zarr_source : Group
+         Zarr group or array to read from.
+     axes : str
+         Description of axes in format STCZYX.
+
+     Returns
+     -------
+     core.Array
+         Pointer to the zarr storage.
+
+     Raises
+     ------
+     ValueError
+         if the zarr source type is not supported
+     ValueError
+         if the data dimensions or the axes length are inconsistent
+     """
+     if isinstance(zarr_source, hierarchy.Group):
+         array = zarr_source[0]
+
+     elif isinstance(zarr_source, storage.DirectoryStore):
+         raise NotImplementedError("DirectoryStore not supported yet")
+
+     elif isinstance(zarr_source, core.Array):
+         # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
+         if zarr_source.dtype == "O":
+             raise NotImplementedError("Object type not supported yet")
+         else:
+             array = zarr_source
+     else:
+         raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")
+
+     # sanity check on dimensions
+     if len(array.shape) < 2 or len(array.shape) > 4:
+         raise ValueError(
+             f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
+         )
+
+     # sanity check on axes length
+     if len(axes) != len(array.shape):
+         raise ValueError(f"Incorrect axes length (got {axes}).")
+
+     # arr = fix_axes(arr, axes)
+     return array
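
And a sketch for the zarr reader. It returns a pointer to the storage rather than loading the data, so the array below lives in memory only (import path again assumed from the file list):

```python
import zarr

from careamics.dataset.dataset_utils.read_zarr import read_zarr

# in-memory zarr array with axes (S, Y, X)
array = zarr.zeros((8, 64, 64), chunks=(1, 64, 64), dtype="float32")
pointer = read_zarr(array, axes="SYX")  # returned as-is, nothing is copied
print(pointer.shape)  # (8, 64, 64)
```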
careamics/dataset/in_memory_dataset.py
@@ -1,154 +1,356 @@
  """In-memory dataset module."""
+ from __future__ import annotations
+
+ import copy
  from pathlib import Path
- from typing import Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, List, Optional, Tuple, Union

  import numpy as np
- import torch
+ from torch.utils.data import Dataset

- from ..utils import normalize
+ from ..config import DataModel, InferenceModel
+ from ..config.tile_information import TileInformation
  from ..utils.logging import get_logger
- from .dataset_utils import (
-     list_files,
-     read_tiff,
+ from .dataset_utils import read_tiff, reshape_array
+ from .patching.patch_transform import get_patch_transform
+ from .patching.patching import (
+     prepare_patches_supervised,
+     prepare_patches_supervised_array,
+     prepare_patches_unsupervised,
+     prepare_patches_unsupervised_array,
  )
- from .extraction_strategy import ExtractionStrategy
- from .patching import generate_patches
+ from .patching.tiled_patching import extract_tiles

  logger = get_logger(__name__)


- class InMemoryDataset(torch.utils.data.Dataset):
+ class InMemoryDataset(Dataset):
+     """Dataset storing data in memory and allowing generating patches from it."""
+
+     def __init__(
+         self,
+         data_config: DataModel,
+         inputs: Union[np.ndarray, List[Path]],
+         data_target: Optional[Union[np.ndarray, List[Path]]] = None,
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Constructor.
+
+         # TODO
+         """
+         self.data_config = data_config
+         self.inputs = inputs
+         self.data_target = data_target
+         self.axes = self.data_config.axes
+         self.patch_size = self.data_config.patch_size
+
+         # read function
+         self.read_source_func = read_source_func
+
+         # Generate patches
+         supervised = self.data_target is not None
+         patches = self._prepare_patches(supervised)
+
+         # Add results to members
+         self.data, self.data_targets, computed_mean, computed_std = patches
+
+         if not self.data_config.mean or not self.data_config.std:
+             self.mean, self.std = computed_mean, computed_std
+             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
+
+             # if the transforms are not an instance of Compose
+             if self.data_config.has_transform_list():
+                 # update mean and std in configuration
+                 # the object is mutable and should then be recorded in the CAREamist obj
+                 self.data_config.set_mean_and_std(self.mean, self.std)
+         else:
+             self.mean, self.std = self.data_config.mean, self.data_config.std
+
+         # get transforms
+         self.patch_transform = get_patch_transform(
+             patch_transforms=self.data_config.transforms,
+             with_target=self.data_target is not None,
+         )
+
+     def _prepare_patches(
+         self, supervised: bool
+     ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
+         """
+         Iterate over data source and create an array of patches.
+
+         Parameters
+         ----------
+         supervised : bool
+             Whether the dataset is supervised or not.
+
+         Returns
+         -------
+         Tuple[np.ndarray, Optional[np.ndarray], float, float]
+             Patches, optional targets, mean and standard deviation.
+         """
+         if supervised:
+             if isinstance(self.inputs, np.ndarray) and isinstance(
+                 self.data_target, np.ndarray
+             ):
+                 return prepare_patches_supervised_array(
+                     self.inputs,
+                     self.axes,
+                     self.data_target,
+                     self.patch_size,
+                 )
+             elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
+                 return prepare_patches_supervised(
+                     self.inputs,
+                     self.data_target,
+                     self.axes,
+                     self.patch_size,
+                     self.read_source_func,
+                 )
+             else:
+                 raise ValueError(
+                     f"Data and target must be of the same type, either both numpy "
+                     f"arrays or both lists of paths, got {type(self.inputs)} (data) "
+                     f"and {type(self.data_target)} (target)."
+                 )
+         else:
+             if isinstance(self.inputs, np.ndarray):
+                 return prepare_patches_unsupervised_array(
+                     self.inputs,
+                     self.axes,
+                     self.patch_size,
+                 )
+             else:
+                 return prepare_patches_unsupervised(
+                     self.inputs,
+                     self.axes,
+                     self.patch_size,
+                     self.read_source_func,
+                 )
+
+     def __len__(self) -> int:
+         """
+         Return the length of the dataset.
+
+         Returns
+         -------
+         int
+             Length of the dataset.
+         """
+         return len(self.data)
+
+     def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+         """
+         Return the patch corresponding to the provided index.
+
+         Parameters
+         ----------
+         index : int
+             Index of the patch to return.
+
+         Returns
+         -------
+         Tuple[np.ndarray]
+             Patch.
+
+         Raises
+         ------
+         ValueError
+             If no target is provided and N2V manipulation is not enabled.
+         """
+         patch = self.data[index]
+
+         # if there is a target
+         if self.data_target is not None:
+             # get target
+             target = self.data_targets[index]
+
+             # Albumentations requires Channel last
+             c_patch = np.moveaxis(patch, 0, -1)
+             c_target = np.moveaxis(target, 0, -1)
+
+             # Apply transforms
+             transformed = self.patch_transform(image=c_patch, target=c_target)
+
+             # move axes back
+             patch = np.moveaxis(transformed["image"], -1, 0)
+             target = np.moveaxis(transformed["target"], -1, 0)
+
+             return patch, target
+
+         elif self.data_config.has_n2v_manipulate():
+             # Albumentations requires Channel last
+             patch = np.moveaxis(patch, 0, -1)
+
+             # Apply transforms
+             transformed_patch = self.patch_transform(image=patch)["image"]
+             manip_patch, patch, mask = transformed_patch
+
+             # move C axes back
+             manip_patch = np.moveaxis(manip_patch, -1, 0)
+             patch = np.moveaxis(patch, -1, 0)
+             mask = np.moveaxis(mask, -1, 0)
+
+             return (manip_patch, patch, mask)
+         else:
+             raise ValueError(
+                 "Something went wrong! No target provided (not supervised training) "
+                 "and no N2V manipulation (no N2V training)."
+             )
+
+     def split_dataset(
+         self,
+         percentage: float = 0.1,
+         minimum_patches: int = 1,
+     ) -> InMemoryDataset:
+         """Split a new dataset away from the current one.
+
+         This method is used to extract random validation patches from the dataset.
+
+         Parameters
+         ----------
+         percentage : float, optional
+             Percentage of patches to extract, by default 0.1.
+         minimum_patches : int, optional
+             Minimum number of patches to extract, by default 1.
+
+         Returns
+         -------
+         InMemoryDataset
+             New dataset with the extracted patches.
+
+         Raises
+         ------
+         ValueError
+             If `percentage` is not between 0 and 1.
+         ValueError
+             If `minimum_patches` is not between 1 and the number of patches.
+         """
+         if percentage < 0 or percentage > 1:
+             raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")
+
+         if minimum_patches < 1 or minimum_patches > len(self):
+             raise ValueError(
+                 f"Minimum number of patches must be between 1 and "
+                 f"{len(self)} (number of patches), got "
+                 f"{minimum_patches}. Adjust the patch size or the minimum number of "
+                 f"patches."
+             )
+
+         total_patches = len(self)
+
+         # number of patches to extract (either percentage rounded or minimum number)
+         n_patches = max(round(total_patches * percentage), minimum_patches)
+
+         # get random indices
+         indices = np.random.choice(total_patches, n_patches, replace=False)
+
+         # extract patches
+         val_patches = self.data[indices]
+
+         # remove patches from self.patch
+         self.data = np.delete(self.data, indices, axis=0)
+
+         # same for targets
+         if self.data_targets is not None:
+             val_targets = self.data_targets[indices]
+             self.data_targets = np.delete(self.data_targets, indices, axis=0)
+
+         # clone the dataset
+         dataset = copy.deepcopy(self)
+
+         # reassign patches
+         dataset.data = val_patches
+
+         # reassign targets
+         if self.data_targets is not None:
+             dataset.data_targets = val_targets
+
+         return dataset
+
+
+ class InMemoryPredictionDataset(Dataset):
      """
      Dataset storing data in memory and allowing generating patches from it.

-     Parameters
-     ----------
-     data_path : Union[str, Path]
-         Path to the data, must be a directory.
-     data_format : str
-         Extension of the data files, without period.
-     axes : str
-         Description of axes in format STCZYX.
-     patch_extraction_method : ExtractionStrategies
-         Patch extraction strategy, as defined in extraction_strategy.
-     patch_size : Union[List[int], Tuple[int]]
-         Size of the patches along each axis, must be of dimension 2 or 3.
-     patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-         Overlap of the patches, must be of dimension 2 or 3, by default None.
-     mean : Optional[float], optional
-         Expected mean of the dataset, by default None.
-     std : Optional[float], optional
-         Expected standard deviation of the dataset, by default None.
-     patch_transform : Optional[Callable], optional
-         Patch transform to apply, by default None.
-     patch_transform_params : Optional[Dict], optional
-         Patch transform parameters, by default None.
+     # TODO
      """

      def __init__(
          self,
-         data_path: Union[str, Path],
-         data_format: str,
-         axes: str,
-         patch_extraction_method: ExtractionStrategy,
-         patch_size: Union[List[int], Tuple[int]],
-         patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
-         mean: Optional[float] = None,
-         std: Optional[float] = None,
-         patch_transform: Optional[Callable] = None,
-         patch_transform_params: Optional[Dict] = None,
+         prediction_config: InferenceModel,
+         inputs: np.ndarray,
+         data_target: Optional[np.ndarray] = None,
+         read_source_func: Optional[Callable] = read_tiff,
      ) -> None:
-         """
-         Constructor.
+         """Constructor.

          Parameters
          ----------
-         data_path : Union[str, Path]
-             Path to the data, must be a directory.
-         data_format : str
-             Extension of the data files, without period.
+         inputs : np.ndarray
+             Array containing the data.
          axes : str
              Description of axes in format STCZYX.
-         patch_extraction_method : ExtractionStrategies
-             Patch extraction strategy, as defined in extraction_strategy.
-         patch_size : Union[List[int], Tuple[int]]
-             Size of the patches along each axis, must be of dimension 2 or 3.
-         patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-             Overlap of the patches, must be of dimension 2 or 3, by default None.
-         mean : Optional[float], optional
-             Expected mean of the dataset, by default None.
-         std : Optional[float], optional
-             Expected standard deviation of the dataset, by default None.
-         patch_transform : Optional[Callable], optional
-             Patch transform to apply, by default None.
-         patch_transform_params : Optional[Dict], optional
-             Patch transform parameters, by default None.

          Raises
          ------
          ValueError
              If data_path is not a directory.
          """
-         self.data_path = Path(data_path)
-         if not self.data_path.is_dir():
-             raise ValueError("Path to data should be an existing folder.")
-
-         self.data_format = data_format
-         self.axes = axes
-
-         self.patch_transform = patch_transform
+         self.pred_config = prediction_config
+         self.input_array = inputs
+         self.axes = self.pred_config.axes
+         self.tile_size = self.pred_config.tile_size
+         self.tile_overlap = self.pred_config.tile_overlap
+         self.mean = self.pred_config.mean
+         self.std = self.pred_config.std
+         self.data_target = data_target

-         self.files = list_files(self.data_path, self.data_format)
+         # tiling only if both tile size and overlap are provided
+         self.tiling = self.tile_size is not None and self.tile_overlap is not None

-         self.patch_size = patch_size
-         self.patch_overlap = patch_overlap
-         self.patch_extraction_method = patch_extraction_method
-         self.patch_transform = patch_transform
-         self.patch_transform_params = patch_transform_params
-
-         self.mean = mean
-         self.std = std
+         # read function
+         self.read_source_func = read_source_func

          # Generate patches
-         self.data, computed_mean, computed_std = self._prepare_patches()
-
-         if not mean or not std:
-             self.mean, self.std = computed_mean, computed_std
-             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
+         self.data = self._prepare_tiles()
+         self.mean, self.std = self.pred_config.mean, self.pred_config.std

-         assert self.mean is not None
-         assert self.std is not None
+         # get transforms
+         self.patch_transform = get_patch_transform(
+             patch_transforms=self.pred_config.transforms,
+             with_target=self.data_target is not None,
+         )

-     def _prepare_patches(self) -> Tuple[np.ndarray, float, float]:
+     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
          """
          Iterate over data source and create an array of patches.

          Returns
          -------
-         np.ndarray
-             Array of patches.
+         List[Tuple[np.ndarray, TileInformation]]
+             List of tiles.
          """
-         means, stds, num_samples = 0, 0, 0
-         self.all_patches = []
-         for filename in self.files:
-             sample = read_tiff(filename, self.axes)
-             means += sample.mean()
-             stds += np.std(sample)
-             num_samples += 1
-
-             # generate patches, return a generator
-             patches = generate_patches(
-                 sample,
-                 self.patch_extraction_method,
-                 self.patch_size,
-                 self.patch_overlap,
+         # reshape array
+         reshaped_sample = reshape_array(self.input_array, self.axes)
+
+         if self.tiling:
+             # generate patches, which returns a generator
+             patch_generator = extract_tiles(
+                 arr=reshaped_sample,
+                 tile_size=self.tile_size,
+                 overlaps=self.tile_overlap,
              )
+             patches_list = list(patch_generator)

-             # convert generator to list and add to all_patches
-             self.all_patches.extend(list(patches))
+             if len(patches_list) == 0:
+                 raise ValueError("No tiles generated.")

-         result_mean, result_std = means / num_samples, stds / num_samples
-         return np.concatenate(self.all_patches), result_mean, result_std
+             return patches_list
+         else:
+             array_shape = reshaped_sample.squeeze().shape
+             return [(reshaped_sample, TileInformation(array_shape=array_shape))]

      def __len__(self) -> int:
          """
@@ -159,10 +361,9 @@ class InMemoryDataset(torch.utils.data.Dataset):
          int
              Length of the dataset.
          """
-         # convert to numpy array to convince mypy that it is not a generator
-         return sum(np.array(s).shape[0] for s in self.all_patches)
+         return len(self.data)

-     def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+     def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
          """
          Return the patch corresponding to the provided index.

@@ -173,29 +374,18 @@ class InMemoryDataset(torch.utils.data.Dataset):

          Returns
          -------
-         Tuple[np.ndarray]
-             Patch.
-
-         Raises
-         ------
-         ValueError
-             If dataset mean and std are not set.
+         Tuple[np.ndarray, TileInformation]
+             Transformed patch.
          """
-         patch = self.data[index].squeeze()
+         tile_array, tile_info = self.data[index]

-         if self.mean is not None and self.std is not None:
-             if isinstance(patch, tuple):
-                 patch = normalize(img=patch[0], mean=self.mean, std=self.std)
-                 patch = (patch, *patch[1:])
-             else:
-                 patch = normalize(img=patch, mean=self.mean, std=self.std)
+         # Albumentations requires channel last
+         patch = np.moveaxis(tile_array, 0, -1)

-             if self.patch_transform is not None:
-                 # replace None self.patch_transform_params with empty dict
-                 if self.patch_transform_params is None:
-                     self.patch_transform_params = {}
+         # Apply transforms
+         transformed_patch = self.patch_transform(image=patch)["image"]

-                 patch = self.patch_transform(patch, **self.patch_transform_params)
-             return patch
-         else:
-             raise ValueError("Dataset mean and std must be set before using it.")
+         # move C axes back
+         transformed_patch = np.moveaxis(transformed_patch, -1, 0)
+
+         return transformed_patch, tile_info
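
Putting the reworked dataset together, a hedged end-to-end sketch. Only the attributes read in the constructor above (axes, patch_size, mean/std, transforms) are known from the diff; the exact DataModel fields and defaults are assumptions and may differ in the released package:

```python
import numpy as np

from careamics.config import DataModel
from careamics.dataset.in_memory_dataset import InMemoryDataset

# hypothetical configuration; field names inferred from the constructor above
config = DataModel(data_type="array", axes="SYX", patch_size=[64, 64])

# unsupervised case: patches are generated from the array at construction time
train_ds = InMemoryDataset(data_config=config, inputs=np.random.rand(2, 256, 256))

# carve out ~10% of the patches into a validation dataset
val_ds = train_ds.split_dataset(percentage=0.1)
print(len(train_ds), len(val_ds))
```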