careamics-0.1.0rc6-py3-none-any.whl → careamics-0.1.0rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the package's page on the registry for more details.

Files changed (91):
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -4,25 +4,25 @@ from __future__ import annotations
4
4
 
5
5
  import copy
6
6
  from pathlib import Path
7
- from typing import Any, Callable, List, Optional, Tuple, Union
7
+ from typing import Any, Callable, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
+ from careamics.file_io.read import read_tiff
12
13
  from careamics.transforms import Compose
13
14
 
14
- from ..config import DataConfig, InferenceConfig
15
- from ..config.tile_information import TileInformation
15
+ from ..config import DataConfig
16
16
  from ..config.transformations import NormalizeModel
17
17
  from ..utils.logging import get_logger
18
- from .dataset_utils import read_tiff, reshape_array
19
18
  from .patching.patching import (
19
+ PatchedOutput,
20
+ Stats,
20
21
  prepare_patches_supervised,
21
22
  prepare_patches_supervised_array,
22
23
  prepare_patches_unsupervised,
23
24
  prepare_patches_unsupervised_array,
24
25
  )
25
- from .patching.tiled_patching import extract_tiles
26
26
 
27
27
  logger = get_logger(__name__)
28
28
 
@@ -32,11 +32,12 @@ class InMemoryDataset(Dataset):
32
32
 
33
33
  Parameters
34
34
  ----------
35
- data_config : DataConfig
35
+ data_config : CAREamics DataConfig
36
+ (see careamics.config.data_model.DataConfig)
36
37
  Data configuration.
37
- inputs : Union[np.ndarray, List[Path]]
38
+ inputs : numpy.ndarray or list[pathlib.Path]
38
39
  Input data.
39
- input_target : Optional[Union[np.ndarray, List[Path]]], optional
40
+ input_target : numpy.ndarray or list[pathlib.Path], optional
40
41
  Target data, by default None.
41
42
  read_source_func : Callable, optional
42
43
  Read source function for custom types, by default read_tiff.
@@ -47,8 +48,8 @@ class InMemoryDataset(Dataset):
47
48
  def __init__(
48
49
  self,
49
50
  data_config: DataConfig,
50
- inputs: Union[np.ndarray, List[Path]],
51
- input_target: Optional[Union[np.ndarray, List[Path]]] = None,
51
+ inputs: Union[np.ndarray, list[Path]],
52
+ input_target: Optional[Union[np.ndarray, list[Path]]] = None,
52
53
  read_source_func: Callable = read_tiff,
53
54
  **kwargs: Any,
54
55
  ) -> None:
@@ -59,9 +60,9 @@ class InMemoryDataset(Dataset):
59
60
  ----------
60
61
  data_config : DataConfig
61
62
  Data configuration.
62
- inputs : Union[np.ndarray, List[Path]]
63
+ inputs : numpy.ndarray or list[pathlib.Path]
63
64
  Input data.
64
- input_target : Optional[Union[np.ndarray, List[Path]]], optional
65
+ input_target : numpy.ndarray or list[pathlib.Path], optional
65
66
  Target data, by default None.
66
67
  read_source_func : Callable, optional
67
68
  Read source function for custom types, by default read_tiff.
@@ -77,31 +78,56 @@ class InMemoryDataset(Dataset):
77
78
  # read function
78
79
  self.read_source_func = read_source_func
79
80
 
80
- # Generate patches
81
+ # generate patches
81
82
  supervised = self.input_targets is not None
82
- patch_data = self._prepare_patches(supervised)
83
-
84
- # Add results to members
85
- self.patches, self.patch_targets, computed_mean, computed_std = patch_data
86
-
87
- if not self.data_config.mean or not self.data_config.std:
88
- self.mean, self.std = computed_mean, computed_std
89
- logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
83
+ patches_data = self._prepare_patches(supervised)
84
+
85
+ # unpack the dataclass
86
+ self.data = patches_data.patches
87
+ self.data_targets = patches_data.targets
88
+
89
+ # set image statistics
90
+ if self.data_config.image_means is None:
91
+ self.image_stats = patches_data.image_stats
92
+ logger.info(
93
+ f"Computed dataset mean: {self.image_stats.means}, "
94
+ f"std: {self.image_stats.stds}"
95
+ )
96
+ else:
97
+ self.image_stats = Stats(
98
+ self.data_config.image_means, self.data_config.image_stds
99
+ )
90
100
 
91
- # update mean and std in configuration
92
- # the object is mutable and should then be recorded in the CAREamist obj
93
- self.data_config.set_mean_and_std(self.mean, self.std)
101
+ # set target statistics
102
+ if self.data_config.target_means is None:
103
+ self.target_stats = patches_data.target_stats
94
104
  else:
95
- self.mean, self.std = self.data_config.mean, self.data_config.std
105
+ self.target_stats = Stats(
106
+ self.data_config.target_means, self.data_config.target_stds
107
+ )
96
108
 
109
+ # update mean and std in configuration
110
+ # the object is mutable and should then be recorded in the CAREamist obj
111
+ self.data_config.set_means_and_stds(
112
+ image_means=self.image_stats.means,
113
+ image_stds=self.image_stats.stds,
114
+ target_means=self.target_stats.means,
115
+ target_stds=self.target_stats.stds,
116
+ )
97
117
  # get transforms
98
118
  self.patch_transform = Compose(
99
- transform_list=self.data_config.transforms,
119
+ transform_list=[
120
+ NormalizeModel(
121
+ image_means=self.image_stats.means,
122
+ image_stds=self.image_stats.stds,
123
+ target_means=self.target_stats.means,
124
+ target_stds=self.target_stats.stds,
125
+ )
126
+ ]
127
+ + self.data_config.transforms,
100
128
  )
101
129
 
102
- def _prepare_patches(
103
- self, supervised: bool
104
- ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
130
+ def _prepare_patches(self, supervised: bool) -> PatchedOutput:
105
131
  """
106
132
  Iterate over data source and create an array of patches.
107
133
 
@@ -112,7 +138,7 @@ class InMemoryDataset(Dataset):
112
138
 
113
139
  Returns
114
140
  -------
115
- np.ndarray
141
+ numpy.ndarray
116
142
  Array of patches.
117
143
  """
118
144
  if supervised:
@@ -163,9 +189,9 @@ class InMemoryDataset(Dataset):
163
189
  int
164
190
  Length of the dataset.
165
191
  """
166
- return len(self.patches)
192
+ return self.data.shape[0]
167
193
 
168
- def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
194
+ def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
169
195
  """
170
196
  Return the patch corresponding to the provided index.
171
197
 
@@ -176,7 +202,7 @@ class InMemoryDataset(Dataset):
176
202
 
177
203
  Returns
178
204
  -------
179
- Tuple[np.ndarray]
205
+ tuple of numpy.ndarray
180
206
  Patch.
181
207
 
182
208
  Raises
@@ -184,16 +210,16 @@ class InMemoryDataset(Dataset):
184
210
  ValueError
185
211
  If dataset mean and std are not set.
186
212
  """
187
- patch = self.patches[index]
213
+ patch = self.data[index]
188
214
 
189
215
  # if there is a target
190
- if self.patch_targets is not None:
216
+ if self.data_targets is not None:
191
217
  # get target
192
- target = self.patch_targets[index]
218
+ target = self.data_targets[index]
193
219
 
194
220
  return self.patch_transform(patch=patch, target=target)
195
221
 
196
- elif self.data_config.has_n2v_manipulate():
222
+ elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
197
223
  return self.patch_transform(patch=patch)
198
224
  else:
199
225
  raise ValueError(
@@ -201,6 +227,18 @@ class InMemoryDataset(Dataset):
201
227
  "and no N2V manipulation (no N2V training)."
202
228
  )
203
229
 
230
+ def get_data_statistics(self) -> tuple[list[float], list[float]]:
231
+ """Return training data statistics.
232
+
233
+ This does not return the target data statistics, only those of the input.
234
+
235
+ Returns
236
+ -------
237
+ tuple of list of floats
238
+ Means and standard deviations across channels of the training data.
239
+ """
240
+ return self.image_stats.get_statistics()
241
+
204
242
  def split_dataset(
205
243
  self,
206
244
  percentage: float = 0.1,
@@ -219,7 +257,7 @@ class InMemoryDataset(Dataset):
219
257
 
220
258
  Returns
221
259
  -------
222
- InMemoryDataset
260
+ CAREamics InMemoryDataset
223
261
  New dataset with the extracted patches.
224
262
 
225
263
  Raises
@@ -249,151 +287,24 @@ class InMemoryDataset(Dataset):
249
287
  indices = np.random.choice(total_patches, n_patches, replace=False)
250
288
 
251
289
  # extract patches
252
- val_patches = self.patches[indices]
290
+ val_patches = self.data[indices]
253
291
 
254
292
  # remove patches from self.patch
255
- self.patches = np.delete(self.patches, indices, axis=0)
293
+ self.data = np.delete(self.data, indices, axis=0)
256
294
 
257
295
  # same for targets
258
- if self.patch_targets is not None:
259
- val_targets = self.patch_targets[indices]
260
- self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
296
+ if self.data_targets is not None:
297
+ val_targets = self.data_targets[indices]
298
+ self.data_targets = np.delete(self.data_targets, indices, axis=0)
261
299
 
262
300
  # clone the dataset
263
301
  dataset = copy.deepcopy(self)
264
302
 
265
303
  # reassign patches
266
- dataset.patches = val_patches
304
+ dataset.data = val_patches
267
305
 
268
306
  # reassign targets
269
- if self.patch_targets is not None:
270
- dataset.patch_targets = val_targets
307
+ if self.data_targets is not None:
308
+ dataset.data_targets = val_targets
271
309
 
272
310
  return dataset
273
-
274
-
275
- class InMemoryPredictionDataset(Dataset):
276
- """
277
- Dataset storing data in memory and allowing generating patches from it.
278
-
279
- Parameters
280
- ----------
281
- prediction_config : InferenceConfig
282
- Prediction configuration.
283
- inputs : np.ndarray
284
- Input data.
285
- data_target : Optional[np.ndarray], optional
286
- Target data, by default None.
287
- read_source_func : Optional[Callable], optional
288
- Read source function for custom types, by default read_tiff.
289
- """
290
-
291
- def __init__(
292
- self,
293
- prediction_config: InferenceConfig,
294
- inputs: np.ndarray,
295
- data_target: Optional[np.ndarray] = None,
296
- read_source_func: Optional[Callable] = read_tiff,
297
- ) -> None:
298
- """Constructor.
299
-
300
- Parameters
301
- ----------
302
- prediction_config : InferenceConfig
303
- Prediction configuration.
304
- inputs : np.ndarray
305
- Input data.
306
- data_target : Optional[np.ndarray], optional
307
- Target data, by default None.
308
- read_source_func : Optional[Callable], optional
309
- Read source function for custom types, by default read_tiff.
310
-
311
- Raises
312
- ------
313
- ValueError
314
- If data_path is not a directory.
315
- """
316
- self.pred_config = prediction_config
317
- self.input_array = inputs
318
- self.axes = self.pred_config.axes
319
- self.tile_size = self.pred_config.tile_size
320
- self.tile_overlap = self.pred_config.tile_overlap
321
- self.mean = self.pred_config.mean
322
- self.std = self.pred_config.std
323
- self.data_target = data_target
324
-
325
- # tiling only if both tile size and overlap are provided
326
- self.tiling = self.tile_size is not None and self.tile_overlap is not None
327
-
328
- # read function
329
- self.read_source_func = read_source_func
330
-
331
- # Generate patches
332
- self.data = self._prepare_tiles()
333
- self.mean, self.std = self.pred_config.mean, self.pred_config.std
334
-
335
- # get transforms
336
- self.patch_transform = Compose(
337
- transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
338
- )
339
-
340
- def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
341
- """
342
- Iterate over data source and create an array of patches.
343
-
344
- Returns
345
- -------
346
- List[XArrayTile]
347
- List of tiles.
348
- """
349
- # reshape array
350
- reshaped_sample = reshape_array(self.input_array, self.axes)
351
-
352
- if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
353
- # generate patches, which returns a generator
354
- patch_generator = extract_tiles(
355
- arr=reshaped_sample,
356
- tile_size=self.tile_size,
357
- overlaps=self.tile_overlap,
358
- )
359
- patches_list = list(patch_generator)
360
-
361
- if len(patches_list) == 0:
362
- raise ValueError("No tiles generated, ")
363
-
364
- return patches_list
365
- else:
366
- array_shape = reshaped_sample.squeeze().shape
367
- return [(reshaped_sample, TileInformation(array_shape=array_shape))]
368
-
369
- def __len__(self) -> int:
370
- """
371
- Return the length of the dataset.
372
-
373
- Returns
374
- -------
375
- int
376
- Length of the dataset.
377
- """
378
- return len(self.data)
379
-
380
- def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
381
- """
382
- Return the patch corresponding to the provided index.
383
-
384
- Parameters
385
- ----------
386
- index : int
387
- Index of the patch to return.
388
-
389
- Returns
390
- -------
391
- Tuple[np.ndarray, TileInformation]
392
- Transformed patch.
393
- """
394
- tile_array, tile_info = self.data[index]
395
-
396
- # Apply transforms
397
- transformed_tile, _ = self.patch_transform(patch=tile_array)
398
-
399
- return transformed_tile, tile_info
@@ -0,0 +1,88 @@
1
+ """In-memory prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.transformations import NormalizeModel
12
+ from .dataset_utils import reshape_array
13
+
14
+
15
+ class InMemoryPredDataset(Dataset):
16
+ """Simple prediction dataset returning images along the sample axis.
17
+
18
+ Parameters
19
+ ----------
20
+ prediction_config : InferenceConfig
21
+ Prediction configuration.
22
+ inputs : NDArray
23
+ Input data.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ prediction_config: InferenceConfig,
29
+ inputs: NDArray,
30
+ ) -> None:
31
+ """Constructor.
32
+
33
+ Parameters
34
+ ----------
35
+ prediction_config : InferenceConfig
36
+ Prediction configuration.
37
+ inputs : NDArray
38
+ Input data.
39
+
40
+ Raises
41
+ ------
42
+ ValueError
43
+ If data_path is not a directory.
44
+ """
45
+ self.pred_config = prediction_config
46
+ self.input_array = inputs
47
+ self.axes = self.pred_config.axes
48
+ self.image_means = self.pred_config.image_means
49
+ self.image_stds = self.pred_config.image_stds
50
+
51
+ # Reshape data
52
+ self.data = reshape_array(self.input_array, self.axes)
53
+
54
+ # get transforms
55
+ self.patch_transform = Compose(
56
+ transform_list=[
57
+ NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
58
+ ],
59
+ )
60
+
61
+ def __len__(self) -> int:
62
+ """
63
+ Return the length of the dataset.
64
+
65
+ Returns
66
+ -------
67
+ int
68
+ Length of the dataset.
69
+ """
70
+ return len(self.data)
71
+
72
+ def __getitem__(self, index: int) -> NDArray:
73
+ """
74
+ Return the patch corresponding to the provided index.
75
+
76
+ Parameters
77
+ ----------
78
+ index : int
79
+ Index of the patch to return.
80
+
81
+ Returns
82
+ -------
83
+ NDArray
84
+ Transformed patch.
85
+ """
86
+ transformed_patch, _ = self.patch_transform(patch=self.data[index])
87
+
88
+ return transformed_patch
@@ -0,0 +1,129 @@
1
+ """In-memory tiled prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.tile_information import TileInformation
12
+ from ..config.transformations import NormalizeModel
13
+ from .dataset_utils import reshape_array
14
+ from .tiling import extract_tiles
15
+
16
+
17
+ class InMemoryTiledPredDataset(Dataset):
18
+ """Prediction dataset storing data in memory and returning tiles of each image.
19
+
20
+ Parameters
21
+ ----------
22
+ prediction_config : InferenceConfig
23
+ Prediction configuration.
24
+ inputs : NDArray
25
+ Input data.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ prediction_config: InferenceConfig,
31
+ inputs: NDArray,
32
+ ) -> None:
33
+ """Constructor.
34
+
35
+ Parameters
36
+ ----------
37
+ prediction_config : InferenceConfig
38
+ Prediction configuration.
39
+ inputs : NDArray
40
+ Input data.
41
+
42
+ Raises
43
+ ------
44
+ ValueError
45
+ If data_path is not a directory.
46
+ """
47
+ if (
48
+ prediction_config.tile_size is None
49
+ or prediction_config.tile_overlap is None
50
+ ):
51
+ raise ValueError(
52
+ "Tile size and overlap must be provided to use the tiled prediction "
53
+ "dataset."
54
+ )
55
+
56
+ self.pred_config = prediction_config
57
+ self.input_array = inputs
58
+ self.axes = self.pred_config.axes
59
+ self.tile_size = prediction_config.tile_size
60
+ self.tile_overlap = prediction_config.tile_overlap
61
+ self.image_means = self.pred_config.image_means
62
+ self.image_stds = self.pred_config.image_stds
63
+
64
+ # Generate patches
65
+ self.data = self._prepare_tiles()
66
+
67
+ # get transforms
68
+ self.patch_transform = Compose(
69
+ transform_list=[
70
+ NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
71
+ ],
72
+ )
73
+
74
+ def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
75
+ """
76
+ Iterate over data source and create an array of patches.
77
+
78
+ Returns
79
+ -------
80
+ list of tuples of NDArray and TileInformation
81
+ List of tiles and tile information.
82
+ """
83
+ # reshape array
84
+ reshaped_sample = reshape_array(self.input_array, self.axes)
85
+
86
+ # generate patches, which returns a generator
87
+ patch_generator = extract_tiles(
88
+ arr=reshaped_sample,
89
+ tile_size=self.tile_size,
90
+ overlaps=self.tile_overlap,
91
+ )
92
+ patches_list = list(patch_generator)
93
+
94
+ if len(patches_list) == 0:
95
+ raise ValueError("No tiles generated, ")
96
+
97
+ return patches_list
98
+
99
+ def __len__(self) -> int:
100
+ """
101
+ Return the length of the dataset.
102
+
103
+ Returns
104
+ -------
105
+ int
106
+ Length of the dataset.
107
+ """
108
+ return len(self.data)
109
+
110
+ def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
111
+ """
112
+ Return the patch corresponding to the provided index.
113
+
114
+ Parameters
115
+ ----------
116
+ index : int
117
+ Index of the patch to return.
118
+
119
+ Returns
120
+ -------
121
+ tuple of NDArray and TileInformation
122
+ Transformed patch.
123
+ """
124
+ tile_array, tile_info = self.data[index]
125
+
126
+ # Apply transforms
127
+ transformed_tile, _ = self.patch_transform(patch=tile_array)
128
+
129
+ return transformed_tile, tile_info