careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69) hide show
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -3,86 +3,29 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import copy
6
+ from collections.abc import Generator
6
7
  from pathlib import Path
7
- from typing import Any, Callable, Generator, List, Optional, Tuple, Union
8
+ from typing import Callable, Optional
8
9
 
9
10
  import numpy as np
10
- from torch.utils.data import IterableDataset, get_worker_info
11
+ from torch.utils.data import IterableDataset
11
12
 
13
+ from careamics.config import DataConfig
14
+ from careamics.config.transformations import NormalizeModel
12
15
  from careamics.transforms import Compose
13
16
 
14
- from ..config import DataConfig, InferenceConfig
15
- from ..config.tile_information import TileInformation
16
- from ..config.transformations import NormalizeModel
17
17
  from ..utils.logging import get_logger
18
- from .dataset_utils import read_tiff, reshape_array
18
+ from .dataset_utils import (
19
+ iterate_over_files,
20
+ read_tiff,
21
+ )
22
+ from .dataset_utils.running_stats import WelfordStatistics
23
+ from .patching.patching import Stats, StatsOutput
19
24
  from .patching.random_patching import extract_patches_random
20
- from .patching.tiled_patching import extract_tiles
21
25
 
22
26
  logger = get_logger(__name__)
23
27
 
24
28
 
25
- def _iterate_over_files(
26
- data_config: Union[DataConfig, InferenceConfig],
27
- data_files: List[Path],
28
- target_files: Optional[List[Path]] = None,
29
- read_source_func: Callable = read_tiff,
30
- ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
31
- """
32
- Iterate over data source and yield whole image.
33
-
34
- Parameters
35
- ----------
36
- data_config : Union[DataConfig, InferenceConfig]
37
- Data configuration.
38
- data_files : List[Path]
39
- List of data files.
40
- target_files : Optional[List[Path]]
41
- List of target files, by default None.
42
- read_source_func : Optional[Callable]
43
- Function to read the source, by default read_tiff.
44
-
45
- Yields
46
- ------
47
- np.ndarray
48
- Image.
49
- """
50
- # When num_workers > 0, each worker process will have a different copy of the
51
- # dataset object
52
- # Configuring each copy independently to avoid having duplicate data returned
53
- # from the workers
54
- worker_info = get_worker_info()
55
- worker_id = worker_info.id if worker_info is not None else 0
56
- num_workers = worker_info.num_workers if worker_info is not None else 1
57
-
58
- # iterate over the files
59
- for i, filename in enumerate(data_files):
60
- # retrieve file corresponding to the worker id
61
- if i % num_workers == worker_id:
62
- try:
63
- # read data
64
- sample = read_source_func(filename, data_config.axes)
65
-
66
- # read target, if available
67
- if target_files is not None:
68
- if filename.name != target_files[i].name:
69
- raise ValueError(
70
- f"File {filename} does not match target file "
71
- f"{target_files[i]}. Have you passed sorted "
72
- f"arrays?"
73
- )
74
-
75
- # read target
76
- target = read_source_func(target_files[i], data_config.axes)
77
-
78
- yield sample, target
79
- else:
80
- yield sample, None
81
-
82
- except Exception as e:
83
- logger.error(f"Error reading file {filename}: {e}")
84
-
85
-
86
29
  class PathIterableDataset(IterableDataset):
87
30
  """
88
31
  Dataset allowing extracting patches w/o loading whole data into memory.
@@ -91,38 +34,26 @@ class PathIterableDataset(IterableDataset):
91
34
  ----------
92
35
  data_config : DataConfig
93
36
  Data configuration.
94
- src_files : List[Path]
37
+ src_files : list of pathlib.Path
95
38
  List of data files.
96
- target_files : Optional[List[Path]], optional
39
+ target_files : list of pathlib.Path, optional
97
40
  Optional list of target files, by default None.
98
41
  read_source_func : Callable, optional
99
42
  Read source function for custom types, by default read_tiff.
100
43
 
101
44
  Attributes
102
45
  ----------
103
- data_path : List[Path]
46
+ data_path : list of pathlib.Path
104
47
  Path to the data, must be a directory.
105
48
  axes : str
106
49
  Description of axes in format STCZYX.
107
- patch_extraction_method : Union[ExtractionStrategies, None]
108
- Patch extraction strategy, as defined in extraction_strategy.
109
- patch_size : Optional[Union[List[int], Tuple[int]]], optional
110
- Size of the patches in each dimension, by default None.
111
- patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
112
- Overlap of the patches in each dimension, by default None.
113
- mean : Optional[float], optional
114
- Expected mean of the dataset, by default None.
115
- std : Optional[float], optional
116
- Expected standard deviation of the dataset, by default None.
117
- patch_transform : Optional[Callable], optional
118
- Patch transform callable, by default None.
119
50
  """
120
51
 
121
52
  def __init__(
122
53
  self,
123
54
  data_config: DataConfig,
124
- src_files: List[Path],
125
- target_files: Optional[List[Path]] = None,
55
+ src_files: list[Path],
56
+ target_files: Optional[list[Path]] = None,
126
57
  read_source_func: Callable = read_tiff,
127
58
  ) -> None:
128
59
  """Constructor.
@@ -131,9 +62,9 @@ class PathIterableDataset(IterableDataset):
131
62
  ----------
132
63
  data_config : DataConfig
133
64
  Data configuration.
134
- src_files : List[Path]
65
+ src_files : list[Path]
135
66
  List of data files.
136
- target_files : Optional[List[Path]], optional
67
+ target_files : list[Path] or None, optional
137
68
  Optional list of target files, by default None.
138
69
  read_source_func : Callable, optional
139
70
  Read source function for custom types, by default read_tiff.
@@ -141,55 +72,102 @@ class PathIterableDataset(IterableDataset):
141
72
  self.data_config = data_config
142
73
  self.data_files = src_files
143
74
  self.target_files = target_files
144
- self.data_config = data_config
145
75
  self.read_source_func = read_source_func
146
76
 
147
77
  # compute mean and std over the dataset
148
- if not data_config.mean or not data_config.std:
149
- self.mean, self.std = self._calculate_mean_and_std()
78
+ # only checking the image_mean because the DataConfig class ensures that
79
+ # if image_mean is provided, image_std is also provided
80
+ if not self.data_config.image_means:
81
+ self.data_stats = self._calculate_mean_and_std()
82
+ logger.info(
83
+ f"Computed dataset mean: {self.data_stats.image_stats.means},"
84
+ f"std: {self.data_stats.image_stats.stds}"
85
+ )
86
+
87
+ # update the mean in the config
88
+ self.data_config.set_mean_and_std(
89
+ image_means=self.data_stats.image_stats.means,
90
+ image_stds=self.data_stats.image_stats.stds,
91
+ target_means=(
92
+ list(self.data_stats.target_stats.means)
93
+ if self.data_stats.target_stats.means is not None
94
+ else None
95
+ ),
96
+ target_stds=(
97
+ list(self.data_stats.target_stats.stds)
98
+ if self.data_stats.target_stats.stds is not None
99
+ else None
100
+ ),
101
+ )
150
102
 
151
- # update mean and std in configuration
152
- # the object is mutable and should then be recorded in the CAREamist
153
- data_config.set_mean_and_std(self.mean, self.std)
154
103
  else:
155
- self.mean = data_config.mean
156
- self.std = data_config.std
104
+ # if mean and std are provided in the config, use them
105
+ self.data_stats = StatsOutput(
106
+ Stats(self.data_config.image_means, self.data_config.image_stds),
107
+ Stats(self.data_config.target_means, self.data_config.target_stds),
108
+ )
157
109
 
158
- # get transforms
159
- self.patch_transform = Compose(transform_list=data_config.transforms)
110
+ # create transform composed of normalization and other transforms
111
+ self.patch_transform = Compose(
112
+ transform_list=[
113
+ NormalizeModel(
114
+ image_means=self.data_stats.image_stats.means,
115
+ image_stds=self.data_stats.image_stats.stds,
116
+ target_means=self.data_stats.target_stats.means,
117
+ target_stds=self.data_stats.target_stats.stds,
118
+ )
119
+ ]
120
+ + data_config.transforms
121
+ )
160
122
 
161
- def _calculate_mean_and_std(self) -> Tuple[float, float]:
123
+ def _calculate_mean_and_std(self) -> StatsOutput:
162
124
  """
163
125
  Calculate mean and std of the dataset.
164
126
 
165
127
  Returns
166
128
  -------
167
- Tuple[float, float]
168
- Tuple containing mean and standard deviation.
129
+ StatsOutput
130
+ Data class containing the image statistics.
169
131
  """
170
- means, stds = 0, 0
171
132
  num_samples = 0
133
+ image_stats = WelfordStatistics()
134
+ if self.target_files is not None:
135
+ target_stats = WelfordStatistics()
172
136
 
173
- for sample, _ in _iterate_over_files(
137
+ for sample, target in iterate_over_files(
174
138
  self.data_config, self.data_files, self.target_files, self.read_source_func
175
139
  ):
176
- means += sample.mean()
177
- stds += sample.std()
140
+ # update the image statistics
141
+ image_stats.update(sample, num_samples)
142
+
143
+ # update the target statistics if target is available
144
+ if target is not None:
145
+ target_stats.update(target, num_samples)
146
+
178
147
  num_samples += 1
179
148
 
180
149
  if num_samples == 0:
181
150
  raise ValueError("No samples found in the dataset.")
182
151
 
183
- result_mean = means / num_samples
184
- result_std = stds / num_samples
152
+ # Finalize the running (Welford) statistics into means and stds
153
+ image_means, image_stds = image_stats.finalize()
154
+
155
+ if target is not None:
156
+ target_means, target_stds = target_stats.finalize()
185
157
 
186
158
  logger.info(f"Calculated mean and std for {num_samples} images")
187
- logger.info(f"Mean: {result_mean}, std: {result_std}")
188
- return result_mean, result_std
159
+ logger.info(f"Mean: {image_means}, std: {image_stds}")
160
+ return StatsOutput(
161
+ Stats(image_means, image_stds),
162
+ Stats(
163
+ np.array(target_means) if target is not None else None,
164
+ np.array(target_stds) if target is not None else None,
165
+ ),
166
+ )
189
167
 
190
168
  def __iter__(
191
169
  self,
192
- ) -> Generator[Tuple[np.ndarray, ...], None, None]:
170
+ ) -> Generator[tuple[np.ndarray, ...], None, None]:
193
171
  """
194
172
  Iterate over data source and yield single patch.
195
173
 
@@ -199,24 +177,18 @@ class PathIterableDataset(IterableDataset):
199
177
  Single patch.
200
178
  """
201
179
  assert (
202
- self.mean is not None and self.std is not None
180
+ self.data_stats.image_stats.means is not None
181
+ and self.data_stats.image_stats.stds is not None
203
182
  ), "Mean and std must be provided"
204
183
 
205
184
  # iterate over files
206
- for sample_input, sample_target in _iterate_over_files(
185
+ for sample_input, sample_target in iterate_over_files(
207
186
  self.data_config, self.data_files, self.target_files, self.read_source_func
208
187
  ):
209
- reshaped_sample = reshape_array(sample_input, self.data_config.axes)
210
- reshaped_target = (
211
- None
212
- if sample_target is None
213
- else reshape_array(sample_target, self.data_config.axes)
214
- )
215
-
216
188
  patches = extract_patches_random(
217
- arr=reshaped_sample,
189
+ arr=sample_input,
218
190
  patch_size=self.data_config.patch_size,
219
- target=reshaped_target,
191
+ target=sample_target,
220
192
  )
221
193
 
222
194
  # iterate over patches
@@ -317,132 +289,3 @@ class PathIterableDataset(IterableDataset):
317
289
  dataset.target_files = val_target_files
318
290
 
319
291
  return dataset
320
-
321
-
322
- class IterablePredictionDataset(IterableDataset):
323
- """
324
- Prediction dataset.
325
-
326
- Parameters
327
- ----------
328
- prediction_config : InferenceConfig
329
- Inference configuration.
330
- src_files : List[Path]
331
- List of data files.
332
- read_source_func : Callable, optional
333
- Read source function for custom types, by default read_tiff.
334
- **kwargs : Any
335
- Additional keyword arguments, unused.
336
-
337
- Attributes
338
- ----------
339
- data_path : Union[str, Path]
340
- Path to the data, must be a directory.
341
- axes : str
342
- Description of axes in format STCZYX.
343
- mean : Optional[float], optional
344
- Expected mean of the dataset, by default None.
345
- std : Optional[float], optional
346
- Expected standard deviation of the dataset, by default None.
347
- patch_transform : Optional[Callable], optional
348
- Patch transform callable, by default None.
349
- """
350
-
351
- def __init__(
352
- self,
353
- prediction_config: InferenceConfig,
354
- src_files: List[Path],
355
- read_source_func: Callable = read_tiff,
356
- **kwargs: Any,
357
- ) -> None:
358
- """Constructor.
359
-
360
- Parameters
361
- ----------
362
- prediction_config : InferenceConfig
363
- Inference configuration.
364
- src_files : List[Path]
365
- List of data files.
366
- read_source_func : Callable, optional
367
- Read source function for custom types, by default read_tiff.
368
- **kwargs : Any
369
- Additional keyword arguments, unused.
370
-
371
- Raises
372
- ------
373
- ValueError
374
- If mean and std are not provided in the inference configuration.
375
- """
376
- self.prediction_config = prediction_config
377
- self.data_files = src_files
378
- self.axes = prediction_config.axes
379
- self.tile_size = self.prediction_config.tile_size
380
- self.tile_overlap = self.prediction_config.tile_overlap
381
- self.read_source_func = read_source_func
382
-
383
- # tile only if both tile size and overlaps are provided
384
- self.tile = self.tile_size is not None and self.tile_overlap is not None
385
-
386
- # check mean and std and create normalize transform
387
- if self.prediction_config.mean is None or self.prediction_config.std is None:
388
- raise ValueError("Mean and std must be provided for prediction.")
389
- else:
390
- self.mean = self.prediction_config.mean
391
- self.std = self.prediction_config.std
392
-
393
- # instantiate normalize transform
394
- self.patch_transform = Compose(
395
- transform_list=[
396
- NormalizeModel(
397
- mean=prediction_config.mean, std=prediction_config.std
398
- )
399
- ],
400
- )
401
-
402
- def __iter__(
403
- self,
404
- ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
405
- """
406
- Iterate over data source and yield single patch.
407
-
408
- Yields
409
- ------
410
- np.ndarray
411
- Single patch.
412
- """
413
- assert (
414
- self.mean is not None and self.std is not None
415
- ), "Mean and std must be provided"
416
-
417
- for sample, _ in _iterate_over_files(
418
- self.prediction_config,
419
- self.data_files,
420
- read_source_func=self.read_source_func,
421
- ):
422
- # reshape array
423
- reshaped_sample = reshape_array(sample, self.axes)
424
-
425
- if (
426
- self.tile
427
- and self.tile_size is not None
428
- and self.tile_overlap is not None
429
- ):
430
- # generate patches, return a generator
431
- patch_gen = extract_tiles(
432
- arr=reshaped_sample,
433
- tile_size=self.tile_size,
434
- overlaps=self.tile_overlap,
435
- )
436
- else:
437
- # just wrap the sample in a generator with default tiling info
438
- array_shape = reshaped_sample.squeeze().shape
439
- patch_gen = (
440
- (reshaped_sample, TileInformation(array_shape=array_shape))
441
- for _ in range(1)
442
- )
443
-
444
- # apply transform to patches
445
- for patch_array, tile_info in patch_gen:
446
- transformed_patch, _ = self.patch_transform(patch=patch_array)
447
-
448
- yield transformed_patch, tile_info
@@ -0,0 +1,121 @@
1
+ """Iterable prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Generator
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import IterableDataset
10
+
11
+ from careamics.transforms import Compose
12
+
13
+ from ..config import InferenceConfig
14
+ from ..config.transformations import NormalizeModel
15
+ from .dataset_utils import iterate_over_files, read_tiff
16
+
17
+
18
+ class IterablePredDataset(IterableDataset):
19
+ """Simple iterable prediction dataset.
20
+
21
+ Parameters
22
+ ----------
23
+ prediction_config : InferenceConfig
24
+ Inference configuration.
25
+ src_files : List[Path]
26
+ List of data files.
27
+ read_source_func : Callable, optional
28
+ Read source function for custom types, by default read_tiff.
29
+ **kwargs : Any
30
+ Additional keyword arguments, unused.
31
+
32
+ Attributes
33
+ ----------
34
+ data_files : list of pathlib.Path
35
+ List of data files.
36
+ axes : str
37
+ Description of axes in format STCZYX.
38
+ image_means
39
+ Image means taken from the inference configuration and passed to the normalization transform; must not be None.
40
+ image_stds
41
+ Image standard deviations taken from the inference configuration and passed to the normalization transform; must not be None.
42
+ patch_transform : Optional[Callable], optional
43
+ Patch transform callable, by default None.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ prediction_config: InferenceConfig,
49
+ src_files: list[Path],
50
+ read_source_func: Callable = read_tiff,
51
+ **kwargs: Any,
52
+ ) -> None:
53
+ """Constructor.
54
+
55
+ Parameters
56
+ ----------
57
+ prediction_config : InferenceConfig
58
+ Inference configuration.
59
+ src_files : list of pathlib.Path
60
+ List of data files.
61
+ read_source_func : Callable, optional
62
+ Read source function for custom types, by default read_tiff.
63
+ **kwargs : Any
64
+ Additional keyword arguments, unused.
65
+
66
+ Raises
67
+ ------
68
+ ValueError
69
+ If mean and std are not provided in the inference configuration.
70
+ """
71
+ self.prediction_config = prediction_config
72
+ self.data_files = src_files
73
+ self.axes = prediction_config.axes
74
+ self.read_source_func = read_source_func
75
+
76
+ # check mean and std and create normalize transform
77
+ if (
78
+ self.prediction_config.image_means is None
79
+ or self.prediction_config.image_stds is None
80
+ ):
81
+ raise ValueError("Mean and std must be provided for prediction.")
82
+ else:
83
+ self.image_means = self.prediction_config.image_means
84
+ self.image_stds = self.prediction_config.image_stds
85
+
86
+ # instantiate normalize transform
87
+ self.patch_transform = Compose(
88
+ transform_list=[
89
+ NormalizeModel(
90
+ image_means=self.image_means,
91
+ image_stds=self.image_stds,
92
+ )
93
+ ],
94
+ )
95
+
96
+ def __iter__(
97
+ self,
98
+ ) -> Generator[NDArray, None, None]:
99
+ """
100
+ Iterate over data source and yield single patch.
101
+
102
+ Yields
103
+ ------
104
+ NDArray
105
+ Single patch.
106
+ """
107
+ assert (
108
+ self.image_means is not None and self.image_stds is not None
109
+ ), "Mean and std must be provided"
110
+
111
+ for sample, _ in iterate_over_files(
112
+ self.prediction_config,
113
+ self.data_files,
114
+ read_source_func=self.read_source_func,
115
+ ):
116
+ # sample has S dimension
117
+ for i in range(sample.shape[0]):
118
+
119
+ transformed_sample, _ = self.patch_transform(patch=sample[i])
120
+
121
+ yield transformed_sample
@@ -0,0 +1,139 @@
1
+ """Iterable tiled prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Generator
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import IterableDataset
10
+
11
+ from careamics.transforms import Compose
12
+
13
+ from ..config import InferenceConfig
14
+ from ..config.tile_information import TileInformation
15
+ from ..config.transformations import NormalizeModel
16
+ from .dataset_utils import iterate_over_files, read_tiff
17
+ from .tiling import extract_tiles
18
+
19
+
20
+ class IterableTiledPredDataset(IterableDataset):
21
+ """Tiled prediction dataset.
22
+
23
+ Parameters
24
+ ----------
25
+ prediction_config : InferenceConfig
26
+ Inference configuration.
27
+ src_files : list of pathlib.Path
28
+ List of data files.
29
+ read_source_func : Callable, optional
30
+ Read source function for custom types, by default read_tiff.
31
+ **kwargs : Any
32
+ Additional keyword arguments, unused.
33
+
34
+ Attributes
35
+ ----------
36
+ data_files : list of pathlib.Path
37
+ List of data files.
38
+ axes : str
39
+ Description of axes in format STCZYX.
40
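+ image_means
41
+ Image means taken from the inference configuration and passed to the normalization transform; must not be None.
42
+ image_stds
43
+ Image standard deviations taken from the inference configuration and passed to the normalization transform; must not be None.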
44
+ patch_transform : Callable, optional
45
+ Patch transform callable, by default None.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ prediction_config: InferenceConfig,
51
+ src_files: list[Path],
52
+ read_source_func: Callable = read_tiff,
53
+ **kwargs: Any,
54
+ ) -> None:
55
+ """Constructor.
56
+
57
+ Parameters
58
+ ----------
59
+ prediction_config : InferenceConfig
60
+ Inference configuration.
61
+ src_files : List[Path]
62
+ List of data files.
63
+ read_source_func : Callable, optional
64
+ Read source function for custom types, by default read_tiff.
65
+ **kwargs : Any
66
+ Additional keyword arguments, unused.
67
+
68
+ Raises
69
+ ------
70
+ ValueError
71
+ If mean and std are not provided in the inference configuration.
72
+ """
73
+ if (
74
+ prediction_config.tile_size is None
75
+ or prediction_config.tile_overlap is None
76
+ ):
77
+ raise ValueError(
78
+ "Tile size and overlap must be provided for tiled prediction."
79
+ )
80
+
81
+ self.prediction_config = prediction_config
82
+ self.data_files = src_files
83
+ self.axes = prediction_config.axes
84
+ self.tile_size = prediction_config.tile_size
85
+ self.tile_overlap = prediction_config.tile_overlap
86
+ self.read_source_func = read_source_func
87
+
88
+ # check mean and std and create normalize transform
89
+ if (
90
+ self.prediction_config.image_means is None
91
+ or self.prediction_config.image_stds is None
92
+ ):
93
+ raise ValueError("Mean and std must be provided for prediction.")
94
+ else:
95
+ self.image_means = self.prediction_config.image_means
96
+ self.image_stds = self.prediction_config.image_stds
97
+
98
+ # instantiate normalize transform
99
+ self.patch_transform = Compose(
100
+ transform_list=[
101
+ NormalizeModel(
102
+ image_means=self.image_means,
103
+ image_stds=self.image_stds,
104
+ )
105
+ ],
106
+ )
107
+
108
+ def __iter__(
109
+ self,
110
+ ) -> Generator[tuple[NDArray, TileInformation], None, None]:
111
+ """
112
+ Iterate over data source and yield single patch.
113
+
114
+ Yields
115
+ ------
116
+ Generator of NDArray and TileInformation tuple
117
+ Generator of single tiles.
118
+ """
119
+ assert (
120
+ self.image_means is not None and self.image_stds is not None
121
+ ), "Mean and std must be provided"
122
+
123
+ for sample, _ in iterate_over_files(
124
+ self.prediction_config,
125
+ self.data_files,
126
+ read_source_func=self.read_source_func,
127
+ ):
128
+ # generate patches, return a generator of single tiles
129
+ patch_gen = extract_tiles(
130
+ arr=sample,
131
+ tile_size=self.tile_size,
132
+ overlaps=self.tile_overlap,
133
+ )
134
+
135
+ # apply transform to patches
136
+ for patch_array, tile_info in patch_gen:
137
+ transformed_patch, _ = self.patch_transform(patch=patch_array)
138
+
139
+ yield transformed_patch, tile_info