careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,310 @@
1
+ """In-memory dataset module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Optional, Union
8
+
9
+ import numpy as np
10
+ from torch.utils.data import Dataset
11
+
12
+ from careamics.file_io.read import read_tiff
13
+ from careamics.transforms import Compose
14
+
15
+ from ..config import DataConfig
16
+ from ..config.transformations import NormalizeModel
17
+ from ..utils.logging import get_logger
18
+ from .patching.patching import (
19
+ PatchedOutput,
20
+ Stats,
21
+ prepare_patches_supervised,
22
+ prepare_patches_supervised_array,
23
+ prepare_patches_unsupervised,
24
+ prepare_patches_unsupervised_array,
25
+ )
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
class InMemoryDataset(Dataset):
    """Dataset storing data in memory and allowing generating patches from it.

    All patches are extracted once, at construction, and kept in memory. Each
    item is normalized (and passed through the configured transforms) when it
    is accessed.

    Parameters
    ----------
    data_config : CAREamics DataConfig
        (see careamics.config.data_model.DataConfig)
        Data configuration.
    inputs : numpy.ndarray or list[pathlib.Path]
        Input data.
    input_target : numpy.ndarray or list[pathlib.Path], optional
        Target data, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        data_config: DataConfig,
        inputs: Union[np.ndarray, list[Path]],
        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """
        Constructor.

        The image and target statistics used for normalization are written
        back into `data_config` (via `set_means_and_stds`), so the passed
        configuration object is modified in place.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration.
        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If `inputs` and `input_target` are given but are not both numpy
            arrays or both lists of paths.
        """
        self.data_config = data_config
        self.inputs = inputs
        self.input_targets = input_target
        self.axes = self.data_config.axes
        self.patch_size = self.data_config.patch_size

        # read function
        self.read_source_func = read_source_func

        # generate patches; providing targets switches to supervised patching
        supervised = self.input_targets is not None
        patches_data = self._prepare_patches(supervised)

        # unpack the dataclass
        self.data = patches_data.patches
        self.data_targets = patches_data.targets

        # set image statistics: computed from the patches unless the
        # configuration already provides them
        if self.data_config.image_means is None:
            self.image_stats = patches_data.image_stats
            logger.info(
                f"Computed dataset mean: {self.image_stats.means}, "
                f"std: {self.image_stats.stds}"
            )
        else:
            self.image_stats = Stats(
                self.data_config.image_means, self.data_config.image_stds
            )

        # set target statistics, same precedence as for the images
        if self.data_config.target_means is None:
            self.target_stats = patches_data.target_stats
        else:
            self.target_stats = Stats(
                self.data_config.target_means, self.data_config.target_stds
            )

        # update mean and std in configuration
        # the object is mutable and should then be recorded in the CAREamist obj
        self.data_config.set_means_and_stds(
            image_means=self.image_stats.means,
            image_stds=self.image_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )

        # get transforms: normalization always runs first, followed by the
        # transforms listed in the configuration
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_stats.means,
                    image_stds=self.image_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
            + self.data_config.transforms,
        )

    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Iterate over data source and create an array of patches.

        Arrays are patched directly; lists of paths are read with
        `self.read_source_func` before patching.

        Parameters
        ----------
        supervised : bool
            Whether the dataset is supervised or not.

        Returns
        -------
        PatchedOutput
            Patches, optional target patches and the computed statistics.

        Raises
        ------
        ValueError
            If supervised and the inputs and targets are not both numpy arrays
            or both lists of paths.
        """
        if supervised:
            if isinstance(self.inputs, np.ndarray) and isinstance(
                self.input_targets, np.ndarray
            ):
                # NOTE: the array variant takes `axes` before the targets
                return prepare_patches_supervised_array(
                    self.inputs,
                    self.axes,
                    self.input_targets,
                    self.patch_size,
                )
            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                return prepare_patches_supervised(
                    self.inputs,
                    self.input_targets,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )
            else:
                raise ValueError(
                    f"Data and target must be of the same type, either both numpy "
                    f"arrays or both lists of paths, got {type(self.inputs)} (data) "
                    f"and {type(self.input_targets)} (target)."
                )
        else:
            if isinstance(self.inputs, np.ndarray):
                return prepare_patches_unsupervised_array(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                )
            else:
                # anything that is not an array is treated as a list of paths
                return prepare_patches_unsupervised(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset, i.e. the number of patches.
        """
        return self.data.shape[0]

    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        tuple of numpy.ndarray
            Transformed patch (and target, if the dataset is supervised).

        Raises
        ------
        ValueError
            If the dataset has no targets and the configuration does not
            contain an N2V manipulation transform.
        """
        patch = self.data[index]

        # if there is a target
        if self.data_targets is not None:
            # get target
            target = self.data_targets[index]

            return self.patch_transform(patch=patch, target=target)

        elif self.data_config.has_n2v_manipulate():  # TODO not compatible with HDN
            return self.patch_transform(patch=patch)
        else:
            raise ValueError(
                "Something went wrong! No target provided (not supervised training) "
                "and no N2V manipulation (no N2V training)."
            )

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return training data statistics.

        This does not return the target data statistics, only those of the input.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_patches: int = 1,
    ) -> InMemoryDataset:
        """Split a new dataset away from the current one.

        This method is used to extract random validation patches from the
        dataset. The extracted patches (and targets) are removed from this
        dataset, which is therefore modified in place.

        Parameters
        ----------
        percentage : float, optional
            Percentage of patches to extract, by default 0.1.
        minimum_patches : int, optional
            Minimum number of patches to extract, by default 1.

        Returns
        -------
        CAREamics InMemoryDataset
            New dataset with the extracted patches.

        Raises
        ------
        ValueError
            If `percentage` is not between 0 and 1.
        ValueError
            If `minimum_patches` is not between 1 and the number of patches.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        total_patches = len(self)

        if minimum_patches < 1 or minimum_patches > total_patches:
            raise ValueError(
                f"Minimum number of patches must be between 1 and "
                f"{total_patches} (number of patches), got "
                f"{minimum_patches}. Adjust the patch size or the minimum number of "
                f"patches."
            )

        # number of patches to extract (either percentage rounded or minimum number)
        n_patches = max(round(total_patches * percentage), minimum_patches)

        # get random indices
        indices = np.random.choice(total_patches, n_patches, replace=False)

        # split patches into validation and remaining training patches
        val_patches = self.data[indices]
        train_patches = np.delete(self.data, indices, axis=0)

        # same for targets
        if self.data_targets is not None:
            val_targets = self.data_targets[indices]
            train_targets = np.delete(self.data_targets, indices, axis=0)
        else:
            val_targets = None
            train_targets = None

        # Clone the dataset while it only holds the validation arrays, so that
        # `deepcopy` does not duplicate the remaining training patches just to
        # overwrite them afterwards. The training arrays are always restored.
        self.data, self.data_targets = val_patches, val_targets
        try:
            dataset = copy.deepcopy(self)
        finally:
            self.data, self.data_targets = train_patches, train_targets

        return dataset
@@ -0,0 +1,88 @@
1
+ """In-memory prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.transformations import NormalizeModel
12
+ from .dataset_utils import reshape_array
13
+
14
+
15
class InMemoryPredDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    The input array is reshaped once at construction using the configured
    axes. Items are taken along the first axis of the reshaped array and
    normalized with the configured image statistics when accessed.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration, providing `axes`, `image_means` and
            `image_stds`.
        inputs : NDArray
            Input data, laid out according to `prediction_config.axes`.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Reshape data according to the configured axes; `__len__` and
        # `__getitem__` index the first axis of the result (presumably the
        # sample axis produced by `reshape_array` - confirm there).
        self.data = reshape_array(self.input_array, self.axes)

        # get transforms: only normalization is applied for prediction
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
            ],
        )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset, i.e. the size of the first axis of the
            reshaped data.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> NDArray:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        NDArray
            Transformed patch.
        """
        # `Compose` returns a pair; the second element (presumably the absent
        # target) is discarded.
        transformed_patch, _ = self.patch_transform(patch=self.data[index])

        return transformed_patch
@@ -0,0 +1,129 @@
1
+ """In-memory tiled prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.tile_information import TileInformation
12
+ from ..config.transformations import NormalizeModel
13
+ from .dataset_utils import reshape_array
14
+ from .tiling import extract_tiles
15
+
16
+
17
class InMemoryTiledPredDataset(Dataset):
    """Prediction dataset storing data in memory and returning tiles of each image.

    All tiles are extracted once at construction. Each tile is normalized with
    the configured image statistics when accessed and returned together with
    its tile information.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration, providing `axes`, `tile_size`,
            `tile_overlap`, `image_means` and `image_stds`.
        inputs : NDArray
            Input data.

        Raises
        ------
        ValueError
            If the tile size or the tile overlap is not set in the prediction
            configuration.
        ValueError
            If no tiles could be generated from the input data.
        """
        # validate before touching the data: tiling is impossible without both
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided to use the tiled prediction "
                "dataset."
            )

        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = self.pred_config.tile_size
        self.tile_overlap = self.pred_config.tile_overlap
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Generate tiles
        self.data = self._prepare_tiles()

        # get transforms: only normalization is applied for prediction
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
            ],
        )

    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
        """
        Reshape the input data and extract all of its tiles.

        Returns
        -------
        list of tuples of NDArray and TileInformation
            List of tiles and tile information.

        Raises
        ------
        ValueError
            If no tiles were generated.
        """
        # reshape array
        reshaped_sample = reshape_array(self.input_array, self.axes)

        # generate tiles, which returns a generator; materialize it so the
        # dataset supports `len` and random access
        tile_generator = extract_tiles(
            arr=reshaped_sample,
            tile_size=self.tile_size,
            overlaps=self.tile_overlap,
        )
        tiles_list = list(tile_generator)

        if not tiles_list:
            raise ValueError(
                f"No tiles generated, check that the tile size ({self.tile_size}) "
                f"and tile overlap ({self.tile_overlap}) are compatible with the "
                f"input data."
            )

        return tiles_list

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset, i.e. the number of tiles.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
        """
        Return the tile corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the tile to return.

        Returns
        -------
        tuple of NDArray and TileInformation
            Transformed tile and its tile information.
        """
        tile_array, tile_info = self.data[index]

        # Apply transforms; the second element returned by `Compose`
        # (presumably the absent target) is discarded.
        transformed_tile, _ = self.patch_transform(patch=tile_array)

        return transformed_tile, tile_info