careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,295 @@
+ """Iterable dataset used to load data file by file."""
+
+ from __future__ import annotations
+
+ import copy
+ from collections.abc import Generator
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import numpy as np
+ from torch.utils.data import IterableDataset
+
+ from careamics.config import DataConfig
+ from careamics.config.transformations import NormalizeModel
+ from careamics.file_io.read import read_tiff
+ from careamics.transforms import Compose
+
+ from ..utils.logging import get_logger
+ from .dataset_utils import iterate_over_files
+ from .dataset_utils.running_stats import WelfordStatistics
+ from .patching.patching import Stats
+ from .patching.random_patching import extract_patches_random
+
+ logger = get_logger(__name__)
+
+
+ class PathIterableDataset(IterableDataset):
+     """
+     Dataset allowing extracting patches w/o loading whole data into memory.
+
+     Parameters
+     ----------
+     data_config : DataConfig
+         Data configuration.
+     src_files : list of pathlib.Path
+         List of data files.
+     target_files : list of pathlib.Path, optional
+         Optional list of target files, by default None.
+     read_source_func : Callable, optional
+         Read source function for custom types, by default read_tiff.
+
+     Attributes
+     ----------
+     data_path : list of pathlib.Path
+         Path to the data, must be a directory.
+     axes : str
+         Description of axes in format STCZYX.
+     """
+
+     def __init__(
+         self,
+         data_config: DataConfig,
+         src_files: list[Path],
+         target_files: Optional[list[Path]] = None,
+         read_source_func: Callable = read_tiff,
+     ) -> None:
+         """Constructor.
+
+         Parameters
+         ----------
+         data_config : DataConfig
+             Data configuration.
+         src_files : list of pathlib.Path
+             List of data files.
+         target_files : list of pathlib.Path or None, optional
+             Optional list of target files, by default None.
+         read_source_func : Callable, optional
+             Read source function for custom types, by default read_tiff.
+         """
+         self.data_config = data_config
+         self.data_files = src_files
+         self.target_files = target_files
+         self.read_source_func = read_source_func
+
+         # compute mean and std over the dataset
+         # only checking image_means because the DataConfig class ensures that
+         # if image_means is provided, image_stds is also provided
+         if not self.data_config.image_means:
+             self.image_stats, self.target_stats = self._calculate_mean_and_std()
+             logger.info(
+                 f"Computed dataset mean: {self.image_stats.means}, "
+                 f"std: {self.image_stats.stds}"
+             )
+
+             # update the mean and std in the config
+             self.data_config.set_means_and_stds(
+                 image_means=self.image_stats.means,
+                 image_stds=self.image_stats.stds,
+                 target_means=(
+                     list(self.target_stats.means)
+                     if self.target_stats.means is not None
+                     else None
+                 ),
+                 target_stds=(
+                     list(self.target_stats.stds)
+                     if self.target_stats.stds is not None
+                     else None
+                 ),
+             )
+         else:
+             # if mean and std are provided in the config, use them
+             self.image_stats, self.target_stats = (
+                 Stats(self.data_config.image_means, self.data_config.image_stds),
+                 Stats(self.data_config.target_means, self.data_config.target_stds),
+             )
+
+         # create transform composed of normalization and other transforms
+         self.patch_transform = Compose(
+             transform_list=[
+                 NormalizeModel(
+                     image_means=self.image_stats.means,
+                     image_stds=self.image_stats.stds,
+                     target_means=self.target_stats.means,
+                     target_stds=self.target_stats.stds,
+                 )
+             ]
+             + data_config.transforms
+         )
+
+     def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
+         """
+         Calculate mean and std of the dataset.
+
+         Returns
+         -------
+         tuple of Stats
+             Data classes containing the image and target statistics.
+         """
+         num_samples = 0
+         image_stats = WelfordStatistics()
+         if self.target_files is not None:
+             target_stats = WelfordStatistics()
+
+         for sample, target in iterate_over_files(
+             self.data_config, self.data_files, self.target_files, self.read_source_func
+         ):
+             # update the image statistics
+             image_stats.update(sample, num_samples)
+
+             # update the target statistics if target is available
+             if target is not None:
+                 target_stats.update(target, num_samples)
+
+             num_samples += 1
+
+         if num_samples == 0:
+             raise ValueError("No samples found in the dataset.")
+
+         # average the means and stds per sample
+         image_means, image_stds = image_stats.finalize()
+
+         if target is not None:
+             target_means, target_stds = target_stats.finalize()
+
+             return (
+                 Stats(image_means, image_stds),
+                 Stats(np.array(target_means), np.array(target_stds)),
+             )
+         else:
+             return Stats(image_means, image_stds), Stats(None, None)
+
+     def __iter__(
+         self,
+     ) -> Generator[tuple[np.ndarray, ...], None, None]:
+         """
+         Iterate over the data source and yield single patches.
+
+         Yields
+         ------
+         np.ndarray
+             Single patch.
+         """
+         assert (
+             self.image_stats.means is not None and self.image_stats.stds is not None
+         ), "Mean and std must be provided"
+
+         # iterate over files
+         for sample_input, sample_target in iterate_over_files(
+             self.data_config, self.data_files, self.target_files, self.read_source_func
+         ):
+             patches = extract_patches_random(
+                 arr=sample_input,
+                 patch_size=self.data_config.patch_size,
+                 target=sample_target,
+             )
+
+             # iterate over patches
+             # patches are tuples of (patch, target) if a target is available,
+             # or (patch, None) if not
+             # patch is of dimensions (C)ZYX
+             for patch_data in patches:
+                 yield self.patch_transform(
+                     patch=patch_data[0],
+                     target=patch_data[1],
+                 )
+
+     def get_data_statistics(self) -> tuple[list[float], list[float]]:
+         """Return training data statistics.
+
+         Returns
+         -------
+         tuple of list of float
+             Means and standard deviations across channels of the training data.
+         """
+         return self.image_stats.get_statistics()
+
+     def get_number_of_files(self) -> int:
+         """
+         Return the number of files in the dataset.
+
+         Returns
+         -------
+         int
+             Number of files in the dataset.
+         """
+         return len(self.data_files)
+
+     def split_dataset(
+         self,
+         percentage: float = 0.1,
+         minimum_number: int = 5,
+     ) -> PathIterableDataset:
+         """Split the dataset in two.
+
+         Extracts a percentage of the data (files), or the minimum number of
+         files if the percentage yields fewer files than the minimum number.
+
+         Parameters
+         ----------
+         percentage : float, optional
+             Percentage of files to split off, by default 0.1.
+         minimum_number : int, optional
+             Minimum number of files to split off, by default 5.
+
+         Returns
+         -------
+         PathIterableDataset
+             Dataset containing the split data.
+
+         Raises
+         ------
+         ValueError
+             If the percentage is smaller than 0 or larger than 1.
+         ValueError
+             If the minimum number is smaller than 1 or larger than the number of files.
+         """
+         if percentage < 0 or percentage > 1:
+             raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")
+
+         if minimum_number < 1 or minimum_number > self.get_number_of_files():
+             raise ValueError(
+                 f"Minimum number of files must be between 1 and "
+                 f"{self.get_number_of_files()} (number of files), got "
+                 f"{minimum_number}."
+             )
+
+         # compute number of files
+         total_files = self.get_number_of_files()
+         n_files = max(round(percentage * total_files), minimum_number)
+
+         # get random indices
+         indices = np.random.choice(total_files, n_files, replace=False)
+
+         # extract files
+         val_files = [self.data_files[i] for i in indices]
+
+         # remove the selected files from self.data_files
+         data_files = []
+         for i, file in enumerate(self.data_files):
+             if i not in indices:
+                 data_files.append(file)
+         self.data_files = data_files
+
+         # same for targets
+         if self.target_files is not None:
+             val_target_files = [self.target_files[i] for i in indices]
+
+             data_target_files = []
+             for i, file in enumerate(self.target_files):
+                 if i not in indices:
+                     data_target_files.append(file)
+             self.target_files = data_target_files
+
+         # clone the dataset
+         dataset = copy.deepcopy(self)
+
+         # reassign files
+         dataset.data_files = val_files
+
+         # reassign targets
+         if self.target_files is not None:
+             dataset.target_files = val_target_files
+
+         return dataset
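
For orientation, here is a minimal usage sketch for PathIterableDataset. It is not part of the release: the file layout, the DataConfig field values, and the assumption that the composed transform yields (patch,) or (patch, target) tuples are all illustrative.

# Hypothetical usage sketch for PathIterableDataset; file layout and
# configuration values are assumptions for illustration only.
from pathlib import Path

from careamics.config import DataConfig
from careamics.dataset.iterable_dataset import PathIterableDataset

# assumed minimal configuration (real configs may set more fields)
config = DataConfig(
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=4,
)

train_files = sorted(Path("data/train").glob("*.tif"))  # assumed layout
dataset = PathIterableDataset(data_config=config, src_files=train_files)

# split off ~10% of the files (at least 5) for validation; this mutates
# the training dataset in place and returns the split-off dataset
val_dataset = dataset.split_dataset(percentage=0.1, minimum_number=5)

# iterating yields transformed patches, (patch,) or (patch, target)
for patch, *maybe_target in dataset:
    ...  # feed the patch to a training step

Note that split_dataset should be called before iteration starts, since it removes the selected files from the training dataset itself.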
@@ -0,0 +1,122 @@
+ """Iterable prediction dataset used to load data file by file."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Callable, Generator
+
+ from numpy.typing import NDArray
+ from torch.utils.data import IterableDataset
+
+ from careamics.file_io.read import read_tiff
+ from careamics.transforms import Compose
+
+ from ..config import InferenceConfig
+ from ..config.transformations import NormalizeModel
+ from .dataset_utils import iterate_over_files
+
+
+ class IterablePredDataset(IterableDataset):
+     """Simple iterable prediction dataset.
+
+     Parameters
+     ----------
+     prediction_config : InferenceConfig
+         Inference configuration.
+     src_files : list of pathlib.Path
+         List of data files.
+     read_source_func : Callable, optional
+         Read source function for custom types, by default read_tiff.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+
+     Attributes
+     ----------
+     data_path : str or pathlib.Path
+         Path to the data, must be a directory.
+     axes : str
+         Description of axes in format STCZYX.
+     mean : float, optional
+         Expected mean of the dataset, by default None.
+     std : float, optional
+         Expected standard deviation of the dataset, by default None.
+     patch_transform : Callable, optional
+         Patch transform callable, by default None.
+     """
+
+     def __init__(
+         self,
+         prediction_config: InferenceConfig,
+         src_files: list[Path],
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         """Constructor.
+
+         Parameters
+         ----------
+         prediction_config : InferenceConfig
+             Inference configuration.
+         src_files : list of pathlib.Path
+             List of data files.
+         read_source_func : Callable, optional
+             Read source function for custom types, by default read_tiff.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+
+         Raises
+         ------
+         ValueError
+             If mean and std are not provided in the inference configuration.
+         """
+         self.prediction_config = prediction_config
+         self.data_files = src_files
+         self.axes = prediction_config.axes
+         self.read_source_func = read_source_func
+
+         # check mean and std and create normalize transform
+         if (
+             self.prediction_config.image_means is None
+             or self.prediction_config.image_stds is None
+         ):
+             raise ValueError("Mean and std must be provided for prediction.")
+         else:
+             self.image_means = self.prediction_config.image_means
+             self.image_stds = self.prediction_config.image_stds
+
+         # instantiate normalize transform
+         self.patch_transform = Compose(
+             transform_list=[
+                 NormalizeModel(
+                     image_means=self.image_means,
+                     image_stds=self.image_stds,
+                 )
+             ],
+         )
+
+     def __iter__(
+         self,
+     ) -> Generator[NDArray, None, None]:
+         """
+         Iterate over the data source and yield single samples.
+
+         Yields
+         ------
+         NDArray
+             Single sample.
+         """
+         assert (
+             self.image_means is not None and self.image_stds is not None
+         ), "Mean and std must be provided"
+
+         for sample, _ in iterate_over_files(
+             self.prediction_config,
+             self.data_files,
+             read_source_func=self.read_source_func,
+         ):
+             # sample has an S dimension
+             for i in range(sample.shape[0]):
+                 transformed_sample, _ = self.patch_transform(patch=sample[i])
+
+                 yield transformed_sample
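
A parallel sketch for IterablePredDataset, assuming an InferenceConfig carrying the statistics computed at training time; the concrete field values and file layout are illustrative, not taken from the release.

# Hypothetical usage sketch for IterablePredDataset; statistics and file
# layout are illustrative assumptions.
from pathlib import Path

from careamics.config import InferenceConfig
from careamics.dataset.iterable_pred_dataset import IterablePredDataset

config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[128.0],  # should match the values recorded at training time
    image_stds=[32.0],
)

pred_files = sorted(Path("data/predict").glob("*.tif"))
dataset = IterablePredDataset(config, pred_files)

# each iteration yields one normalized sample (one per S index of each file)
for sample in dataset:
    ...  # run the model on `sample`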
@@ -0,0 +1,140 @@
+ """Iterable tiled prediction dataset used to load data file by file."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Callable, Generator
+
+ from numpy.typing import NDArray
+ from torch.utils.data import IterableDataset
+
+ from careamics.file_io.read import read_tiff
+ from careamics.transforms import Compose
+
+ from ..config import InferenceConfig
+ from ..config.tile_information import TileInformation
+ from ..config.transformations import NormalizeModel
+ from .dataset_utils import iterate_over_files
+ from .tiling import extract_tiles
+
+
+ class IterableTiledPredDataset(IterableDataset):
+     """Tiled prediction dataset.
+
+     Parameters
+     ----------
+     prediction_config : InferenceConfig
+         Inference configuration.
+     src_files : list of pathlib.Path
+         List of data files.
+     read_source_func : Callable, optional
+         Read source function for custom types, by default read_tiff.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+
+     Attributes
+     ----------
+     data_path : str or pathlib.Path
+         Path to the data, must be a directory.
+     axes : str
+         Description of axes in format STCZYX.
+     mean : float, optional
+         Expected mean of the dataset, by default None.
+     std : float, optional
+         Expected standard deviation of the dataset, by default None.
+     patch_transform : Callable, optional
+         Patch transform callable, by default None.
+     """
+
+     def __init__(
+         self,
+         prediction_config: InferenceConfig,
+         src_files: list[Path],
+         read_source_func: Callable = read_tiff,
+         **kwargs: Any,
+     ) -> None:
+         """Constructor.
+
+         Parameters
+         ----------
+         prediction_config : InferenceConfig
+             Inference configuration.
+         src_files : list of pathlib.Path
+             List of data files.
+         read_source_func : Callable, optional
+             Read source function for custom types, by default read_tiff.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+
+         Raises
+         ------
+         ValueError
+             If tile size and overlap, or mean and std, are not provided in the
+             inference configuration.
+         """
+         if (
+             prediction_config.tile_size is None
+             or prediction_config.tile_overlap is None
+         ):
+             raise ValueError(
+                 "Tile size and overlap must be provided for tiled prediction."
+             )
+
+         self.prediction_config = prediction_config
+         self.data_files = src_files
+         self.axes = prediction_config.axes
+         self.tile_size = prediction_config.tile_size
+         self.tile_overlap = prediction_config.tile_overlap
+         self.read_source_func = read_source_func
+
+         # check mean and std and create normalize transform
+         if (
+             self.prediction_config.image_means is None
+             or self.prediction_config.image_stds is None
+         ):
+             raise ValueError("Mean and std must be provided for prediction.")
+         else:
+             self.image_means = self.prediction_config.image_means
+             self.image_stds = self.prediction_config.image_stds
+
+         # instantiate normalize transform
+         self.patch_transform = Compose(
+             transform_list=[
+                 NormalizeModel(
+                     image_means=self.image_means,
+                     image_stds=self.image_stds,
+                 )
+             ],
+         )
+
+     def __iter__(
+         self,
+     ) -> Generator[tuple[NDArray, TileInformation], None, None]:
+         """
+         Iterate over the data source and yield single tiles.
+
+         Yields
+         ------
+         tuple of NDArray and TileInformation
+             Single tile and its tile information.
+         """
+         assert (
+             self.image_means is not None and self.image_stds is not None
+         ), "Mean and std must be provided"
+
+         for sample, _ in iterate_over_files(
+             self.prediction_config,
+             self.data_files,
+             read_source_func=self.read_source_func,
+         ):
+             # generate tiles, returned as a generator of single tiles
+             patch_gen = extract_tiles(
+                 arr=sample,
+                 tile_size=self.tile_size,
+                 overlaps=self.tile_overlap,
+             )
+
+             # apply transform to tiles
+             for patch_array, tile_info in patch_gen:
+                 transformed_patch, _ = self.patch_transform(patch=patch_array)
+
+                 yield transformed_patch, tile_info
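
Finally, a sketch for the tiled variant, which additionally requires tile_size and tile_overlap and yields the TileInformation needed to stitch per-tile predictions back together (see careamics/prediction_utils/stitch_prediction.py in the file list above). The tile sizes, overlaps, and statistics below are again illustrative assumptions.

# Hypothetical usage sketch for IterableTiledPredDataset; tile sizes,
# overlaps, and statistics are illustrative assumptions.
from pathlib import Path

from careamics.config import InferenceConfig
from careamics.dataset.iterable_tiled_pred_dataset import IterableTiledPredDataset

config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[128.0],
    image_stds=[32.0],
    tile_size=[256, 256],    # required for tiled prediction
    tile_overlap=[48, 48],   # required for tiled prediction
)

files = sorted(Path("data/predict").glob("*.tif"))
dataset = IterableTiledPredDataset(config, files)

# each item is a normalized tile plus the TileInformation needed later
# to stitch per-tile predictions back into full images
for tile, tile_info in dataset:
    ...  # predict on `tile`, keep `tile_info` for stitching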
@@ -0,0 +1 @@
+ """Patching and tiling functions."""