careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.
Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
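The listing shows the rc2 training stack removed (engine.py, config/config.py, dataset/tiff_dataset.py, bioimage/) and a PyTorch Lightning based stack added in its place (careamist.py, lightning_module.py, lightning_datamodule.py, model_io/). As a rough orientation only, here is a hedged sketch of how the new entry point plausibly fits together; `create_n2v_configuration` and the exact `CAREamist` call signatures are assumptions inferred from the file names above, not taken from this diff.

```python
# Hypothetical sketch (not part of the diff): the rc4 layout suggests a
# config-factory + CAREamist workflow replacing the removed Engine.
from careamics import CAREamist
from careamics.config import create_n2v_configuration  # assumed factory name

# build a configuration via the new factory module (configuration_factory.py)
config = create_n2v_configuration(
    experiment_name="n2v_example",
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=10,
)

# CAREamist (careamist.py) appears to replace the removed Engine (engine.py)
careamist = CAREamist(source=config)
careamist.train(train_source="data/train/")              # assumed signature
prediction = careamist.predict(source="data/test/img.tif")
```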
careamics/dataset/dataset_utils/read_tiff.py
@@ -0,0 +1,61 @@
+ import logging
+ from fnmatch import fnmatch
+ from pathlib import Path
+
+ import numpy as np
+ import tifffile
+
+ from careamics.config.support import SupportedData
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
+     """
+     Read a tiff file and return a numpy array.
+
+     Parameters
+     ----------
+     file_path : Path
+         Path to a file.
+     axes : str
+         Description of axes in format STCZYX.
+
+     Returns
+     -------
+     np.ndarray
+         Resulting array.
+
+     Raises
+     ------
+     ValueError
+         If the file failed to open.
+     OSError
+         If the file failed to open.
+     ValueError
+         If the file is not a valid tiff.
+     ValueError
+         If the data dimensions are incorrect.
+     ValueError
+         If the axes length is incorrect.
+     """
+     if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)):
+         try:
+             array = tifffile.imread(file_path)
+         except (ValueError, OSError) as e:
+             logging.exception(f"Exception in file {file_path}: {e}, skipping it.")
+             raise e
+     else:
+         raise ValueError(f"File {file_path} is not a valid tiff.")
+
+     # check dimensions
+     # TODO or should this really be done here? probably in the LightningDataModule
+     # TODO this should also be centralized somewhere else (validate_dimensions)
+     if len(array.shape) < 2 or len(array.shape) > 6:
+         raise ValueError(
+             f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
+             f"file {file_path})."
+         )
+
+     return array
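Note that the dimension check accepts 2 to 6 dimensions while the error message still says "Must be 2, 3 or 4", and the f-string is missing a space between "for" and "file". For context, a hypothetical usage sketch (the path is illustrative):

```python
# Hypothetical usage of the new read_tiff; the file path is illustrative.
from pathlib import Path

from careamics.dataset.dataset_utils.read_tiff import read_tiff

image = read_tiff(Path("data/sample.tif"))  # suffix matched against the TIFF pattern
print(image.shape)  # any array with 2 to 6 dimensions passes the check above
```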
careamics/dataset/dataset_utils/read_utils.py
@@ -0,0 +1,25 @@
+ from typing import Callable, Union
+
+ from careamics.config.support import SupportedData
+
+ from .read_tiff import read_tiff
+
+
+ def get_read_func(data_type: Union[SupportedData, str]) -> Callable:
+     """
+     Get the read function for the data type.
+
+     Parameters
+     ----------
+     data_type : SupportedData
+         Data type.
+
+     Returns
+     -------
+     Callable
+         Read function.
+     """
+     if data_type == SupportedData.TIFF:
+         return read_tiff
+     else:
+         raise NotImplementedError(f"Data type {data_type} is not supported.")
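The dispatcher currently wires up TIFF only; other sources are expected to arrive through the `read_source_func` argument seen elsewhere in this diff. A minimal sketch, assuming `SupportedData` is a str-backed enum so plain strings compare equal to its members:

```python
# Minimal sketch of the dispatch added in read_utils.py.
from careamics.config.support import SupportedData
from careamics.dataset.dataset_utils.read_utils import get_read_func

read_func = get_read_func(SupportedData.TIFF)  # returns read_tiff
# any other value falls through to the else branch:
# get_read_func("zarr")  -> NotImplementedError
```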
careamics/dataset/dataset_utils/read_zarr.py
@@ -0,0 +1,56 @@
+ from typing import Union
+
+ from zarr import Group, core, hierarchy, storage
+
+
+ def read_zarr(
+     zarr_source: Group, axes: str
+ ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
+     """Reads a file and returns a pointer.
+
+     Parameters
+     ----------
+     file_path : Path
+         pathlib.Path object containing a path to a file
+
+     Returns
+     -------
+     np.ndarray
+         Pointer to zarr storage
+
+     Raises
+     ------
+     ValueError, OSError
+         if a file is not a valid tiff or damaged
+     ValueError
+         if data dimensions are not 2, 3 or 4
+     ValueError
+         if axes parameter from config is not consistent with data dimensions
+     """
+     if isinstance(zarr_source, hierarchy.Group):
+         array = zarr_source[0]
+
+     elif isinstance(zarr_source, storage.DirectoryStore):
+         raise NotImplementedError("DirectoryStore not supported yet")
+
+     elif isinstance(zarr_source, core.Array):
+         # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
+         if zarr_source.dtype == "O":
+             raise NotImplementedError("Object type not supported yet")
+         else:
+             array = zarr_source
+     else:
+         raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")
+
+     # sanity check on dimensions
+     if len(array.shape) < 2 or len(array.shape) > 4:
+         raise ValueError(
+             f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
+         )
+
+     # sanity check on axes length
+     if len(axes) != len(array.shape):
+         raise ValueError(f"Incorrect axes length (got {axes}).")
+
+     # arr = fix_axes(arr, axes)
+     return array
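`read_zarr` passes a plain zarr array through after checking its dimensionality (2 to 4) and that the `axes` string matches; groups resolve to their first entry, while object-dtype arrays and `DirectoryStore` inputs raise `NotImplementedError`. A minimal sketch, assuming the zarr v2 creation API:

```python
# Minimal sketch (assumes zarr v2): a core.Array passes through read_zarr.
import numpy as np
import zarr

from careamics.dataset.dataset_utils.read_zarr import read_zarr

arr = zarr.array(np.zeros((128, 128)))  # core.Array with dtype float64
pointer = read_zarr(arr, axes="YX")     # returned as-is after shape/axes checks
```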
careamics/dataset/in_memory_dataset.py
@@ -1,155 +1,356 @@
  """In-memory dataset module."""
+ from __future__ import annotations
+
+ import copy
  from pathlib import Path
- from typing import Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, List, Optional, Tuple, Union

  import numpy as np
- import torch
-
- from careamics.utils import normalize
- from careamics.utils.logging import get_logger
-
- from .dataset_utils import (
-     list_files,
-     read_tiff,
+ from torch.utils.data import Dataset
+
+ from ..config import DataConfig, InferenceConfig
+ from ..config.tile_information import TileInformation
+ from ..utils.logging import get_logger
+ from .dataset_utils import read_tiff, reshape_array
+ from .patching.patch_transform import get_patch_transform
+ from .patching.patching import (
+     prepare_patches_supervised,
+     prepare_patches_supervised_array,
+     prepare_patches_unsupervised,
+     prepare_patches_unsupervised_array,
  )
- from .extraction_strategy import ExtractionStrategy
- from .patching import generate_patches
+ from .patching.tiled_patching import extract_tiles

  logger = get_logger(__name__)


- class InMemoryDataset(torch.utils.data.Dataset):
+ class InMemoryDataset(Dataset):
+     """Dataset storing data in memory and allowing generating patches from it."""
+
+     def __init__(
+         self,
+         data_config: DataConfig,
+         inputs: Union[np.ndarray, List[Path]],
+         data_target: Optional[Union[np.ndarray, List[Path]]] = None,
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Constructor.
+
+         # TODO
+         """
+         self.data_config = data_config
+         self.inputs = inputs
+         self.data_target = data_target
+         self.axes = self.data_config.axes
+         self.patch_size = self.data_config.patch_size
+
+         # read function
+         self.read_source_func = read_source_func
+
+         # Generate patches
+         supervised = self.data_target is not None
+         patches = self._prepare_patches(supervised)
+
+         # Add results to members
+         self.data, self.data_targets, computed_mean, computed_std = patches
+
+         if not self.data_config.mean or not self.data_config.std:
+             self.mean, self.std = computed_mean, computed_std
+             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
+
+             # if the transforms are not an instance of Compose
+             if self.data_config.has_transform_list():
+                 # update mean and std in configuration
+                 # the object is mutable and should then be recorded in the CAREamist obj
+                 self.data_config.set_mean_and_std(self.mean, self.std)
+         else:
+             self.mean, self.std = self.data_config.mean, self.data_config.std
+
+         # get transforms
+         self.patch_transform = get_patch_transform(
+             patch_transforms=self.data_config.transforms,
+             with_target=self.data_target is not None,
+         )
+
+     def _prepare_patches(
+         self, supervised: bool
+     ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
+         """
+         Iterate over data source and create an array of patches.
+
+         Parameters
+         ----------
+         supervised : bool
+             Whether the dataset is supervised or not.
+
+         Returns
+         -------
+         np.ndarray
+             Array of patches.
+         """
+         if supervised:
+             if isinstance(self.inputs, np.ndarray) and isinstance(
+                 self.data_target, np.ndarray
+             ):
+                 return prepare_patches_supervised_array(
+                     self.inputs,
+                     self.axes,
+                     self.data_target,
+                     self.patch_size,
+                 )
+             elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
+                 return prepare_patches_supervised(
+                     self.inputs,
+                     self.data_target,
+                     self.axes,
+                     self.patch_size,
+                     self.read_source_func,
+                 )
+             else:
+                 raise ValueError(
+                     f"Data and target must be of the same type, either both numpy "
+                     f"arrays or both lists of paths, got {type(self.inputs)} (data) "
+                     f"and {type(self.data_target)} (target)."
+                 )
+         else:
+             if isinstance(self.inputs, np.ndarray):
+                 return prepare_patches_unsupervised_array(
+                     self.inputs,
+                     self.axes,
+                     self.patch_size,
+                 )
+             else:
+                 return prepare_patches_unsupervised(
+                     self.inputs,
+                     self.axes,
+                     self.patch_size,
+                     self.read_source_func,
+                 )
+
+     def __len__(self) -> int:
+         """
+         Return the length of the dataset.
+
+         Returns
+         -------
+         int
+             Length of the dataset.
+         """
+         return len(self.data)
+
+     def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+         """
+         Return the patch corresponding to the provided index.
+
+         Parameters
+         ----------
+         index : int
+             Index of the patch to return.
+
+         Returns
+         -------
+         Tuple[np.ndarray]
+             Patch.
+
+         Raises
+         ------
+         ValueError
+             If dataset mean and std are not set.
+         """
+         patch = self.data[index]
+
+         # if there is a target
+         if self.data_target is not None:
+             # get target
+             target = self.data_targets[index]
+
+             # Albumentations requires Channel last
+             c_patch = np.moveaxis(patch, 0, -1)
+             c_target = np.moveaxis(target, 0, -1)
+
+             # Apply transforms
+             transformed = self.patch_transform(image=c_patch, target=c_target)
+
+             # move axes back
+             patch = np.moveaxis(transformed["image"], -1, 0)
+             target = np.moveaxis(transformed["target"], -1, 0)
+
+             return patch, target
+
+         elif self.data_config.has_n2v_manipulate():
+             # Albumentations requires Channel last
+             patch = np.moveaxis(patch, 0, -1)
+
+             # Apply transforms
+             transformed_patch = self.patch_transform(image=patch)["image"]
+             manip_patch, patch, mask = transformed_patch
+
+             # move C axes back
+             manip_patch = np.moveaxis(manip_patch, -1, 0)
+             patch = np.moveaxis(patch, -1, 0)
+             mask = np.moveaxis(mask, -1, 0)
+
+             return (manip_patch, patch, mask)
+         else:
+             raise ValueError(
+                 "Something went wrong! No target provided (not supervised training) "
+                 "and no N2V manipulation (no N2V training)."
+             )
+
+     def split_dataset(
+         self,
+         percentage: float = 0.1,
+         minimum_patches: int = 1,
+     ) -> InMemoryDataset:
+         """Split a new dataset away from the current one.
+
+         This method is used to extract random validation patches from the dataset.
+
+         Parameters
+         ----------
+         percentage : float, optional
+             Percentage of patches to extract, by default 0.1.
+         minimum_patches : int, optional
+             Minimum number of patches to extract, by default 5.
+
+         Returns
+         -------
+         InMemoryDataset
+             New dataset with the extracted patches.
+
+         Raises
+         ------
+         ValueError
+             If `percentage` is not between 0 and 1.
+         ValueError
+             If `minimum_number` is not between 1 and the number of patches.
+         """
+         if percentage < 0 or percentage > 1:
+             raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")
+
+         if minimum_patches < 1 or minimum_patches > len(self):
+             raise ValueError(
+                 f"Minimum number of patches must be between 1 and "
+                 f"{len(self)} (number of patches), got "
+                 f"{minimum_patches}. Adjust the patch size or the minimum number of "
+                 f"patches."
+             )
+
+         total_patches = len(self)
+
+         # number of patches to extract (either percentage rounded or minimum number)
+         n_patches = max(round(total_patches * percentage), minimum_patches)
+
+         # get random indices
+         indices = np.random.choice(total_patches, n_patches, replace=False)
+
+         # extract patches
+         val_patches = self.data[indices]
+
+         # remove patches from self.patch
+         self.data = np.delete(self.data, indices, axis=0)
+
+         # same for targets
+         if self.data_targets is not None:
+             val_targets = self.data_targets[indices]
+             self.data_targets = np.delete(self.data_targets, indices, axis=0)
+
+         # clone the dataset
+         dataset = copy.deepcopy(self)
+
+         # reassign patches
+         dataset.data = val_patches
+
+         # reassign targets
+         if self.data_targets is not None:
+             dataset.data_targets = val_targets
+
+         return dataset
+
+
+ class InMemoryPredictionDataset(Dataset):
      """
      Dataset storing data in memory and allowing generating patches from it.

-     Parameters
-     ----------
-     data_path : Union[str, Path]
-         Path to the data, must be a directory.
-     data_format : str
-         Extension of the data files, without period.
-     axes : str
-         Description of axes in format STCZYX.
-     patch_extraction_method : ExtractionStrategies
-         Patch extraction strategy, as defined in extraction_strategy.
-     patch_size : Union[List[int], Tuple[int]]
-         Size of the patches along each axis, must be of dimension 2 or 3.
-     patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-         Overlap of the patches, must be of dimension 2 or 3, by default None.
-     mean : Optional[float], optional
-         Expected mean of the dataset, by default None.
-     std : Optional[float], optional
-         Expected standard deviation of the dataset, by default None.
-     patch_transform : Optional[Callable], optional
-         Patch transform to apply, by default None.
-     patch_transform_params : Optional[Dict], optional
-         Patch transform parameters, by default None.
+     # TODO
      """

      def __init__(
          self,
-         data_path: Union[str, Path],
-         data_format: str,
-         axes: str,
-         patch_extraction_method: ExtractionStrategy,
-         patch_size: Union[List[int], Tuple[int]],
-         patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
-         mean: Optional[float] = None,
-         std: Optional[float] = None,
-         patch_transform: Optional[Callable] = None,
-         patch_transform_params: Optional[Dict] = None,
+         prediction_config: InferenceConfig,
+         inputs: np.ndarray,
+         data_target: Optional[np.ndarray] = None,
+         read_source_func: Optional[Callable] = read_tiff,
      ) -> None:
-         """
-         Constructor.
+         """Constructor.

          Parameters
          ----------
-         data_path : Union[str, Path]
-             Path to the data, must be a directory.
-         data_format : str
-             Extension of the data files, without period.
+         array : np.ndarray
+             Array containing the data.
          axes : str
              Description of axes in format STCZYX.
-         patch_extraction_method : ExtractionStrategies
-             Patch extraction strategy, as defined in extraction_strategy.
-         patch_size : Union[List[int], Tuple[int]]
-             Size of the patches along each axis, must be of dimension 2 or 3.
-         patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-             Overlap of the patches, must be of dimension 2 or 3, by default None.
-         mean : Optional[float], optional
-             Expected mean of the dataset, by default None.
-         std : Optional[float], optional
-             Expected standard deviation of the dataset, by default None.
-         patch_transform : Optional[Callable], optional
-             Patch transform to apply, by default None.
-         patch_transform_params : Optional[Dict], optional
-             Patch transform parameters, by default None.

          Raises
          ------
          ValueError
              If data_path is not a directory.
          """
-         self.data_path = Path(data_path)
-         if not self.data_path.is_dir():
-             raise ValueError("Path to data should be an existing folder.")
-
-         self.data_format = data_format
-         self.axes = axes
+         self.pred_config = prediction_config
+         self.input_array = inputs
+         self.axes = self.pred_config.axes
+         self.tile_size = self.pred_config.tile_size
+         self.tile_overlap = self.pred_config.tile_overlap
+         self.mean = self.pred_config.mean
+         self.std = self.pred_config.std
+         self.data_target = data_target

-         self.patch_transform = patch_transform
+         # tiling only if both tile size and overlap are provided
+         self.tiling = self.tile_size is not None and self.tile_overlap is not None

-         self.files = list_files(self.data_path, self.data_format)
-
-         self.patch_size = patch_size
-         self.patch_overlap = patch_overlap
-         self.patch_extraction_method = patch_extraction_method
-         self.patch_transform = patch_transform
-         self.patch_transform_params = patch_transform_params
-
-         self.mean = mean
-         self.std = std
+         # read function
+         self.read_source_func = read_source_func

          # Generate patches
-         self.data, computed_mean, computed_std = self._prepare_patches()
+         self.data = self._prepare_tiles()
+         self.mean, self.std = self.pred_config.mean, self.pred_config.std

-         if not mean or not std:
-             self.mean, self.std = computed_mean, computed_std
-             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
+         # get transforms
+         self.patch_transform = get_patch_transform(
+             patch_transforms=self.pred_config.transforms,
+             with_target=self.data_target is not None,
+         )

-         assert self.mean is not None
-         assert self.std is not None
-
-     def _prepare_patches(self) -> Tuple[np.ndarray, float, float]:
+     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
          """
          Iterate over data source and create an array of patches.

          Returns
          -------
-         np.ndarray
-             Array of patches.
+         List[XArrayTile]
+             List of tiles.
          """
-         means, stds, num_samples = 0, 0, 0
-         self.all_patches = []
-         for filename in self.files:
-             sample = read_tiff(filename, self.axes)
-             means += sample.mean()
-             stds += np.std(sample)
-             num_samples += 1
-
-             # generate patches, return a generator
-             patches = generate_patches(
-                 sample,
-                 self.patch_extraction_method,
-                 self.patch_size,
-                 self.patch_overlap,
+         # reshape array
+         reshaped_sample = reshape_array(self.input_array, self.axes)
+
+         if self.tiling:
+             # generate patches, which returns a generator
+             patch_generator = extract_tiles(
+                 arr=reshaped_sample,
+                 tile_size=self.tile_size,
+                 overlaps=self.tile_overlap,
              )
+             patches_list = list(patch_generator)

-             # convert generator to list and add to all_patches
-             self.all_patches.extend(list(patches))
+             if len(patches_list) == 0:
+                 raise ValueError("No tiles generated, ")

-         result_mean, result_std = means / num_samples, stds / num_samples
-         return np.concatenate(self.all_patches), result_mean, result_std
+             return patches_list
+         else:
+             array_shape = reshaped_sample.squeeze().shape
+             return [(reshaped_sample, TileInformation(array_shape=array_shape))]

      def __len__(self) -> int:
          """
@@ -160,10 +361,9 @@ class InMemoryDataset(torch.utils.data.Dataset):
          int
              Length of the dataset.
          """
-         # convert to numpy array to convince mypy that it is not a generator
-         return sum(np.array(s).shape[0] for s in self.all_patches)
+         return len(self.data)

-     def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+     def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
          """
          Return the patch corresponding to the provided index.

@@ -174,29 +374,18 @@
          Returns
          -------
-         Tuple[np.ndarray]
-             Patch.
-
-         Raises
-         ------
-         ValueError
-             If dataset mean and std are not set.
+         Tuple[np.ndarray, TileInformation]
+             Transformed patch.
          """
-         patch = self.data[index].squeeze()
+         tile_array, tile_info = self.data[index]

-         if self.mean is not None and self.std is not None:
-             if isinstance(patch, tuple):
-                 patch = normalize(img=patch[0], mean=self.mean, std=self.std)
-                 patch = (patch, *patch[1:])
-             else:
-                 patch = normalize(img=patch, mean=self.mean, std=self.std)
+         # Albumentations requires channel last, use the XArrayTile array
+         patch = np.moveaxis(tile_array, 0, -1)

-         if self.patch_transform is not None:
-             # replace None self.patch_transform_params with empty dict
-             if self.patch_transform_params is None:
-                 self.patch_transform_params = {}
+         # Apply transforms
+         transformed_patch = self.patch_transform(image=patch)["image"]

-             patch = self.patch_transform(patch, **self.patch_transform_params)
-             return patch
-         else:
-             raise ValueError("Dataset mean and std must be set before using it.")
+         # move C axes back
+         transformed_patch = np.moveaxis(transformed_patch, -1, 0)
+
+         return transformed_patch, tile_info
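Two details worth flagging in the new `InMemoryDataset`: `split_dataset` mutates the original dataset (it deletes the sampled patches before deep-copying the object for validation), and its docstring says the minimum defaults to 5 while the signature defaults to 1. A hedged usage sketch follows; the `DataConfig` field values are assumptions, not taken from this diff:

```python
# Hedged sketch: building the new InMemoryDataset from an array and
# splitting off validation patches. DataConfig field values are assumed.
import numpy as np

from careamics.config import DataConfig
from careamics.dataset.in_memory_dataset import InMemoryDataset

config = DataConfig(
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
)

train_ds = InMemoryDataset(data_config=config, inputs=np.random.rand(256, 256))
val_ds = train_ds.split_dataset(percentage=0.1)  # removes patches from train_ds
print(len(train_ds), len(val_ds))
```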