careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
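The renames above reorganize the package surface: Lightning components move under careamics/lightning/, file readers and writers under careamics/file_io/, tiling under careamics/dataset/tiling/, and prediction stitching under careamics/prediction_utils/. A sketch of the implied import migration follows; the rc8 module paths are taken directly from the renamed files, but which symbols the new __init__.py files re-export is an assumption.

# Hypothetical import migration implied by the renames above (rc6 -> rc8).
# Module paths match the file list; re-exported names are assumptions.

# rc6
# from careamics.dataset.dataset_utils.read_tiff import read_tiff
# from careamics.prediction import stitch_prediction
# from careamics import lightning_module

# rc8
from careamics.file_io.read import read_tiff              # file_io/read/tiff.py
from careamics.prediction_utils import stitch_prediction  # prediction_utils/stitch_prediction.py
from careamics.lightning import lightning_module          # lightning/lightning_module.py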
--- a/careamics/dataset/iterable_dataset.py
+++ b/careamics/dataset/iterable_dataset.py
@@ -3,86 +3,27 @@
 from __future__ import annotations
 
 import copy
+from collections.abc import Generator
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Callable, Optional
 
 import numpy as np
-from torch.utils.data import IterableDataset, get_worker_info
+from torch.utils.data import IterableDataset
 
+from careamics.config import DataConfig
+from careamics.config.transformations import NormalizeModel
+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose
 
-from ..config import DataConfig, InferenceConfig
-from ..config.tile_information import TileInformation
-from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff, reshape_array
+from .dataset_utils import iterate_over_files
+from .dataset_utils.running_stats import WelfordStatistics
+from .patching.patching import Stats
 from .patching.random_patching import extract_patches_random
-from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
 
 
-def _iterate_over_files(
-    data_config: Union[DataConfig, InferenceConfig],
-    data_files: List[Path],
-    target_files: Optional[List[Path]] = None,
-    read_source_func: Callable = read_tiff,
-) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
-    """
-    Iterate over data source and yield whole image.
-
-    Parameters
-    ----------
-    data_config : Union[DataConfig, InferenceConfig]
-        Data configuration.
-    data_files : List[Path]
-        List of data files.
-    target_files : Optional[List[Path]]
-        List of target files, by default None.
-    read_source_func : Optional[Callable]
-        Function to read the source, by default read_tiff.
-
-    Yields
-    ------
-    np.ndarray
-        Image.
-    """
-    # When num_workers > 0, each worker process will have a different copy of the
-    # dataset object
-    # Configuring each copy independently to avoid having duplicate data returned
-    # from the workers
-    worker_info = get_worker_info()
-    worker_id = worker_info.id if worker_info is not None else 0
-    num_workers = worker_info.num_workers if worker_info is not None else 1
-
-    # iterate over the files
-    for i, filename in enumerate(data_files):
-        # retrieve file corresponding to the worker id
-        if i % num_workers == worker_id:
-            try:
-                # read data
-                sample = read_source_func(filename, data_config.axes)
-
-                # read target, if available
-                if target_files is not None:
-                    if filename.name != target_files[i].name:
-                        raise ValueError(
-                            f"File {filename} does not match target file "
-                            f"{target_files[i]}. Have you passed sorted "
-                            f"arrays?"
-                        )
-
-                    # read target
-                    target = read_source_func(target_files[i], data_config.axes)
-
-                    yield sample, target
-                else:
-                    yield sample, None
-
-            except Exception as e:
-                logger.error(f"Error reading file {filename}: {e}")
-
-
 class PathIterableDataset(IterableDataset):
     """
     Dataset allowing extracting patches w/o loading whole data into memory.
@@ -91,38 +32,26 @@ class PathIterableDataset(IterableDataset):
     ----------
     data_config : DataConfig
         Data configuration.
-    src_files : List[Path]
+    src_files : list of pathlib.Path
        List of data files.
-    target_files : Optional[List[Path]], optional
+    target_files : list of pathlib.Path, optional
        Optional list of target files, by default None.
     read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
 
     Attributes
     ----------
-    data_path : List[Path]
+    data_path : list of pathlib.Path
        Path to the data, must be a directory.
     axes : str
        Description of axes in format STCZYX.
-    patch_extraction_method : Union[ExtractionStrategies, None]
-        Patch extraction strategy, as defined in extraction_strategy.
-    patch_size : Optional[Union[List[int], Tuple[int]]], optional
-        Size of the patches in each dimension, by default None.
-    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
-        Overlap of the patches in each dimension, by default None.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
     """
 
     def __init__(
         self,
         data_config: DataConfig,
-        src_files: List[Path],
-        target_files: Optional[List[Path]] = None,
+        src_files: list[Path],
+        target_files: Optional[list[Path]] = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
         """Constructors.
@@ -131,9 +60,9 @@ class PathIterableDataset(IterableDataset):
         ----------
         data_config : DataConfig
             Data configuration.
-        src_files : List[Path]
+        src_files : list[Path]
            List of data files.
-        target_files : Optional[List[Path]], optional
+        target_files : list[Path] or None, optional
            Optional list of target files, by default None.
         read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
@@ -141,55 +70,99 @@ class PathIterableDataset(IterableDataset):
         self.data_config = data_config
         self.data_files = src_files
         self.target_files = target_files
-        self.data_config = data_config
         self.read_source_func = read_source_func
 
         # compute mean and std over the dataset
-        if not data_config.mean or not data_config.std:
-            self.mean, self.std = self._calculate_mean_and_std()
+        # only checking the image_mean because the DataConfig class ensures that
+        # if image_mean is provided, image_std is also provided
+        if not self.data_config.image_means:
+            self.image_stats, self.target_stats = self._calculate_mean_and_std()
+            logger.info(
+                f"Computed dataset mean: {self.image_stats.means},"
+                f"std: {self.image_stats.stds}"
+            )
+
+            # update the mean in the config
+            self.data_config.set_means_and_stds(
+                image_means=self.image_stats.means,
+                image_stds=self.image_stats.stds,
+                target_means=(
+                    list(self.target_stats.means)
+                    if self.target_stats.means is not None
+                    else None
+                ),
+                target_stds=(
+                    list(self.target_stats.stds)
+                    if self.target_stats.stds is not None
+                    else None
+                ),
+            )
 
-            # update mean and std in configuration
-            # the object is mutable and should then be recorded in the CAREamist
-            data_config.set_mean_and_std(self.mean, self.std)
         else:
-            self.mean = data_config.mean
-            self.std = data_config.std
+            # if mean and std are provided in the config, use them
+            self.image_stats, self.target_stats = (
+                Stats(self.data_config.image_means, self.data_config.image_stds),
+                Stats(self.data_config.target_means, self.data_config.target_stds),
+            )
 
-        # get transforms
-        self.patch_transform = Compose(transform_list=data_config.transforms)
+        # create transform composed of normalization and other transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.image_stats.means,
+                    image_stds=self.image_stats.stds,
+                    target_means=self.target_stats.means,
+                    target_stds=self.target_stats.stds,
+                )
+            ]
+            + data_config.transforms
+        )
 
-    def _calculate_mean_and_std(self) -> Tuple[float, float]:
+    def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
         """
         Calculate mean and std of the dataset.
 
         Returns
         -------
-        Tuple[float, float]
-            Tuple containing mean and standard deviation.
+        tuple of Stats and optional Stats
+            Data classes containing the image and target statistics.
         """
-        means, stds = 0, 0
         num_samples = 0
+        image_stats = WelfordStatistics()
+        if self.target_files is not None:
+            target_stats = WelfordStatistics()
 
-        for sample, _ in _iterate_over_files(
+        for sample, target in iterate_over_files(
             self.data_config, self.data_files, self.target_files, self.read_source_func
         ):
-            means += sample.mean()
-            stds += sample.std()
+            # update the image statistics
+            image_stats.update(sample, num_samples)
+
+            # update the target statistics if target is available
+            if target is not None:
+                target_stats.update(target, num_samples)
+
             num_samples += 1
 
         if num_samples == 0:
             raise ValueError("No samples found in the dataset.")
 
-        result_mean = means / num_samples
-        result_std = stds / num_samples
+        # Average the means and stds per sample
+        image_means, image_stds = image_stats.finalize()
 
-        logger.info(f"Calculated mean and std for {num_samples} images")
-        logger.info(f"Mean: {result_mean}, std: {result_std}")
-        return result_mean, result_std
+        if target is not None:
+            target_means, target_stds = target_stats.finalize()
+
+            return (
+                Stats(image_means, image_stds),
+                Stats(np.array(target_means), np.array(target_stds)),
+            )
+        else:
+            return Stats(image_means, image_stds), Stats(None, None)
 
     def __iter__(
         self,
-    ) -> Generator[Tuple[np.ndarray, ...], None, None]:
+    ) -> Generator[tuple[np.ndarray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
@@ -199,24 +172,17 @@ class PathIterableDataset(IterableDataset):
             Single patch.
         """
         assert (
-            self.mean is not None and self.std is not None
+            self.image_stats.means is not None and self.image_stats.stds is not None
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in _iterate_over_files(
+        for sample_input, sample_target in iterate_over_files(
             self.data_config, self.data_files, self.target_files, self.read_source_func
         ):
-            reshaped_sample = reshape_array(sample_input, self.data_config.axes)
-            reshaped_target = (
-                None
-                if sample_target is None
-                else reshape_array(sample_target, self.data_config.axes)
-            )
-
             patches = extract_patches_random(
-                arr=reshaped_sample,
+                arr=sample_input,
                 patch_size=self.data_config.patch_size,
-                target=reshaped_target,
+                target=sample_target,
             )
 
             # iterate over patches
@@ -229,6 +195,16 @@ class PathIterableDataset(IterableDataset):
                 target=patch_data[1],
             )
 
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list of floats
+            Means and standard deviations across channels of the training data.
+        """
+        return self.image_stats.get_statistics()
+
     def get_number_of_files(self) -> int:
         """
         Return the number of files in the dataset.
@@ -317,132 +293,3 @@ class PathIterableDataset(IterableDataset):
         dataset.target_files = val_target_files
 
         return dataset
-
-
-class IterablePredictionDataset(IterableDataset):
-    """
-    Prediction dataset.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Inference configuration.
-    src_files : List[Path]
-        List of data files.
-    read_source_func : Callable, optional
-        Read source function for custom types, by default read_tiff.
-    **kwargs : Any
-        Additional keyword arguments, unused.
-
-    Attributes
-    ----------
-    data_path : Union[str, Path]
-        Path to the data, must be a directory.
-    axes : str
-        Description of axes in format STCZYX.
-    mean : Optional[float], optional
-        Expected mean of the dataset, by default None.
-    std : Optional[float], optional
-        Expected standard deviation of the dataset, by default None.
-    patch_transform : Optional[Callable], optional
-        Patch transform callable, by default None.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        src_files: List[Path],
-        read_source_func: Callable = read_tiff,
-        **kwargs: Any,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Inference configuration.
-        src_files : List[Path]
-            List of data files.
-        read_source_func : Callable, optional
-            Read source function for custom types, by default read_tiff.
-        **kwargs : Any
-            Additional keyword arguments, unused.
-
-        Raises
-        ------
-        ValueError
-            If mean and std are not provided in the inference configuration.
-        """
-        self.prediction_config = prediction_config
-        self.data_files = src_files
-        self.axes = prediction_config.axes
-        self.tile_size = self.prediction_config.tile_size
-        self.tile_overlap = self.prediction_config.tile_overlap
-        self.read_source_func = read_source_func
-
-        # tile only if both tile size and overlaps are provided
-        self.tile = self.tile_size is not None and self.tile_overlap is not None
-
-        # check mean and std and create normalize transform
-        if self.prediction_config.mean is None or self.prediction_config.std is None:
-            raise ValueError("Mean and std must be provided for prediction.")
-        else:
-            self.mean = self.prediction_config.mean
-            self.std = self.prediction_config.std
-
-        # instantiate normalize transform
-        self.patch_transform = Compose(
-            transform_list=[
-                NormalizeModel(
-                    mean=prediction_config.mean, std=prediction_config.std
-                )
-            ],
-        )
-
-    def __iter__(
-        self,
-    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-        """
-        Iterate over data source and yield single patch.
-
-        Yields
-        ------
-        np.ndarray
-            Single patch.
-        """
-        assert (
-            self.mean is not None and self.std is not None
-        ), "Mean and std must be provided"
-
-        for sample, _ in _iterate_over_files(
-            self.prediction_config,
-            self.data_files,
-            read_source_func=self.read_source_func,
-        ):
-            # reshape array
-            reshaped_sample = reshape_array(sample, self.axes)
-
-            if (
-                self.tile
-                and self.tile_size is not None
-                and self.tile_overlap is not None
-            ):
-                # generate patches, return a generator
-                patch_gen = extract_tiles(
-                    arr=reshaped_sample,
-                    tile_size=self.tile_size,
-                    overlaps=self.tile_overlap,
-                )
-            else:
-                # just wrap the sample in a generator with default tiling info
-                array_shape = reshaped_sample.squeeze().shape
-                patch_gen = (
-                    (reshaped_sample, TileInformation(array_shape=array_shape))
-                    for _ in range(1)
-                )
-
-            # apply transform to patches
-            for patch_array, tile_info in patch_gen:
-                transformed_patch, _ = self.patch_transform(patch=patch_array)
-
-                yield transformed_patch, tile_info
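Two changes in this file are worth isolating. First, the module-private _iterate_over_files helper (removed above) moves to careamics/dataset/dataset_utils/iterate_over_files.py (file 24 in the list); its core trick is round-robin assignment of files to DataLoader workers, so copies of the dataset in different worker processes do not yield duplicate data. A minimal standalone sketch of that pattern, assuming the relocated helper keeps the same logic:

from torch.utils.data import get_worker_info

def shard_files(files):
    """Yield only the files assigned to the current DataLoader worker."""
    info = get_worker_info()  # None in the main process (num_workers=0)
    worker_id = info.id if info is not None else 0
    num_workers = info.num_workers if info is not None else 1
    for i, f in enumerate(files):
        # round-robin: worker k handles files k, k + num_workers, ...
        if i % num_workers == worker_id:
            yield f

Second, _calculate_mean_and_std no longer averages per-image means and stds (the mean of per-image stds is not the pooled standard deviation of the dataset) but accumulates exact running statistics via WelfordStatistics (file 25). A minimal single-channel sketch of Welford's online algorithm, for illustration only; the careamics class exposes the update(sample, n)/finalize() API visible in the diff:

import numpy as np

class Welford:
    """Running mean and population std over a stream of arrays (one channel)."""

    def __init__(self):
        self.n = 0        # number of values seen
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations

    def update(self, arr):
        for x in np.asarray(arr, dtype=float).ravel():
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)

    def finalize(self):
        # assumes at least one value has been seen
        return self.mean, float(np.sqrt(self.m2 / self.n))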
--- /dev/null
+++ b/careamics/dataset/iterable_pred_dataset.py
@@ -0,0 +1,122 @@
+"""Iterable prediction dataset used to load data file by file."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Callable, Generator
+
+from numpy.typing import NDArray
+from torch.utils.data import IterableDataset
+
+from careamics.file_io.read import read_tiff
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.transformations import NormalizeModel
+from .dataset_utils import iterate_over_files
+
+
+class IterablePredDataset(IterableDataset):
+    """Simple iterable prediction dataset.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : List[Path]
+        List of data files.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+
+    Attributes
+    ----------
+    data_path : Union[str, Path]
+        Path to the data, must be a directory.
+    axes : str
+        Description of axes in format STCZYX.
+    mean : Optional[float], optional
+        Expected mean of the dataset, by default None.
+    std : Optional[float], optional
+        Expected standard deviation of the dataset, by default None.
+    patch_transform : Optional[Callable], optional
+        Patch transform callable, by default None.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        src_files: list[Path],
+        read_source_func: Callable = read_tiff,
+        **kwargs: Any,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Inference configuration.
+        src_files : list of pathlib.Path
+            List of data files.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the inference configuration.
+        """
+        self.prediction_config = prediction_config
+        self.data_files = src_files
+        self.axes = prediction_config.axes
+        self.read_source_func = read_source_func
+
+        # check mean and std and create normalize transform
+        if (
+            self.prediction_config.image_means is None
+            or self.prediction_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.image_means = self.prediction_config.image_means
+            self.image_stds = self.prediction_config.image_stds
+
+        # instantiate normalize transform
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.image_means,
+                    image_stds=self.image_stds,
+                )
+            ],
+        )
+
+    def __iter__(
+        self,
+    ) -> Generator[NDArray, None, None]:
+        """
+        Iterate over data source and yield single patch.
+
+        Yields
+        ------
+        NDArray
+            Single patch.
+        """
+        assert (
+            self.image_means is not None and self.image_stds is not None
+        ), "Mean and std must be provided"
+
+        for sample, _ in iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
+            # sample has S dimension
+            for i in range(sample.shape[0]):
+
+                transformed_sample, _ = self.patch_transform(patch=sample[i])
+
+                yield transformed_sample
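A hedged usage sketch for the new IterablePredDataset follows. The InferenceConfig fields shown (axes, image_means, image_stds) appear in the diff; the remaining constructor arguments and their validation live in careamics/config/inference_model.py, so treat the concrete values, and the data_type field, as illustrative assumptions.

from pathlib import Path

from torch.utils.data import DataLoader

from careamics.config import InferenceConfig
from careamics.dataset.iterable_pred_dataset import IterablePredDataset

# illustrative config; InferenceConfig may require additional fields
config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[128.0],  # required: the constructor raises ValueError otherwise
    image_stds=[32.0],
)

files = sorted(Path("to_predict").glob("*.tiff"))
dataset = IterablePredDataset(config, src_files=files)

# each item is one normalized sample (the S dimension is unrolled in __iter__)
loader = DataLoader(dataset, batch_size=1)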