careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (81)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +80 -44
  4. careamics/config/algorithm_model.py +5 -3
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +8 -1
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -2
  12. careamics/config/configuration_factory.py +4 -16
  13. careamics/config/data_model.py +10 -14
  14. careamics/config/inference_model.py +0 -65
  15. careamics/config/optimizer_models.py +4 -4
  16. careamics/config/support/__init__.py +0 -2
  17. careamics/config/support/supported_activations.py +2 -0
  18. careamics/config/support/supported_algorithms.py +3 -1
  19. careamics/config/support/supported_architectures.py +2 -0
  20. careamics/config/support/supported_data.py +2 -0
  21. careamics/config/support/supported_loggers.py +2 -0
  22. careamics/config/support/supported_losses.py +2 -0
  23. careamics/config/support/supported_optimizers.py +2 -0
  24. careamics/config/support/supported_pixel_manipulations.py +3 -3
  25. careamics/config/support/supported_struct_axis.py +2 -0
  26. careamics/config/support/supported_transforms.py +4 -15
  27. careamics/config/tile_information.py +2 -0
  28. careamics/config/transformations/__init__.py +3 -2
  29. careamics/config/transformations/xy_flip_model.py +43 -0
  30. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  31. careamics/conftest.py +12 -0
  32. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  33. careamics/dataset/dataset_utils/file_utils.py +4 -3
  34. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  35. careamics/dataset/dataset_utils/read_utils.py +2 -0
  36. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  37. careamics/dataset/in_memory_dataset.py +71 -32
  38. careamics/dataset/iterable_dataset.py +155 -68
  39. careamics/dataset/patching/patching.py +56 -15
  40. careamics/dataset/patching/random_patching.py +8 -2
  41. careamics/dataset/patching/sequential_patching.py +14 -8
  42. careamics/dataset/patching/tiled_patching.py +3 -1
  43. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  44. careamics/dataset/zarr_dataset.py +2 -0
  45. careamics/lightning_datamodule.py +45 -19
  46. careamics/lightning_module.py +8 -2
  47. careamics/lightning_prediction_datamodule.py +3 -13
  48. careamics/lightning_prediction_loop.py +8 -6
  49. careamics/losses/__init__.py +2 -3
  50. careamics/losses/loss_factory.py +1 -1
  51. careamics/losses/losses.py +11 -7
  52. careamics/model_io/bmz_io.py +3 -3
  53. careamics/models/activation.py +2 -0
  54. careamics/models/layers.py +121 -25
  55. careamics/models/model_factory.py +1 -1
  56. careamics/models/unet.py +35 -14
  57. careamics/prediction/stitch_prediction.py +2 -6
  58. careamics/transforms/__init__.py +2 -2
  59. careamics/transforms/compose.py +33 -7
  60. careamics/transforms/n2v_manipulate.py +49 -13
  61. careamics/transforms/normalize.py +55 -3
  62. careamics/transforms/pixel_manipulation.py +5 -5
  63. careamics/transforms/struct_mask_parameters.py +3 -1
  64. careamics/transforms/transform.py +10 -19
  65. careamics/transforms/xy_flip.py +123 -0
  66. careamics/transforms/xy_random_rotate90.py +38 -5
  67. careamics/utils/base_enum.py +28 -0
  68. careamics/utils/path_utils.py +2 -0
  69. careamics/utils/ram.py +2 -0
  70. careamics/utils/receptive_field.py +93 -87
  71. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
  72. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  73. careamics/config/noise_models.py +0 -162
  74. careamics/config/support/supported_extraction_strategies.py +0 -25
  75. careamics/config/transformations/nd_flip_model.py +0 -27
  76. careamics/losses/noise_model_factory.py +0 -40
  77. careamics/losses/noise_models.py +0 -524
  78. careamics/transforms/nd_flip.py +0 -67
  79. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  80. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  81. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0

careamics/dataset/in_memory_dataset.py

@@ -13,6 +13,7 @@ from careamics.transforms import Compose
 
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
 from .patching.patching import (
@@ -27,24 +28,49 @@ logger = get_logger(__name__)
 
 
 class InMemoryDataset(Dataset):
-    """Dataset storing data in memory and allowing generating patches from it."""
+    """Dataset storing data in memory and allowing generating patches from it.
+
+    Parameters
+    ----------
+    data_config : DataConfig
+        Data configuration.
+    inputs : Union[np.ndarray, List[Path]]
+        Input data.
+    input_target : Optional[Union[np.ndarray, List[Path]]], optional
+        Target data, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+    """
 
     def __init__(
         self,
         data_config: DataConfig,
         inputs: Union[np.ndarray, List[Path]],
-        data_target: Optional[Union[np.ndarray, List[Path]]] = None,
+        input_target: Optional[Union[np.ndarray, List[Path]]] = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
         """
         Constructor.
 
-        # TODO
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        inputs : Union[np.ndarray, List[Path]]
+            Input data.
+        input_target : Optional[Union[np.ndarray, List[Path]]], optional
+            Target data, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
         """
         self.data_config = data_config
         self.inputs = inputs
-        self.data_target = data_target
+        self.input_targets = input_target
         self.axes = self.data_config.axes
         self.patch_size = self.data_config.patch_size
 
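
Note: the hunk above renames the rc5 constructor keyword `data_target` to `input_target` (stored internally as `input_targets`). A minimal sketch of the new call, assuming a valid `DataConfig` instance named `config` (its fields are not shown in this diff):

    import numpy as np
    from careamics.dataset.in_memory_dataset import InMemoryDataset

    train = np.random.rand(256, 256).astype(np.float32)
    target = np.random.rand(256, 256).astype(np.float32)

    # rc5: InMemoryDataset(data_config=config, inputs=train, data_target=target)
    # rc6: the keyword is now input_target
    dataset = InMemoryDataset(data_config=config, inputs=train, input_target=target)
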
@@ -52,11 +78,11 @@ class InMemoryDataset(Dataset):
         self.read_source_func = read_source_func
 
         # Generate patches
-        supervised = self.data_target is not None
-        patches = self._prepare_patches(supervised)
+        supervised = self.input_targets is not None
+        patch_data = self._prepare_patches(supervised)
 
         # Add results to members
-        self.data, self.data_targets, computed_mean, computed_std = patches
+        self.patches, self.patch_targets, computed_mean, computed_std = patch_data
 
         if not self.data_config.mean or not self.data_config.std:
             self.mean, self.std = computed_mean, computed_std
@@ -91,18 +117,18 @@ class InMemoryDataset(Dataset):
         """
         if supervised:
             if isinstance(self.inputs, np.ndarray) and isinstance(
-                self.data_target, np.ndarray
+                self.input_targets, np.ndarray
             ):
                 return prepare_patches_supervised_array(
                     self.inputs,
                     self.axes,
-                    self.data_target,
+                    self.input_targets,
                     self.patch_size,
                 )
-            elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
+            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                 return prepare_patches_supervised(
                     self.inputs,
-                    self.data_target,
+                    self.input_targets,
                     self.axes,
                     self.patch_size,
                     self.read_source_func,
@@ -111,7 +137,7 @@ class InMemoryDataset(Dataset):
                 raise ValueError(
                     f"Data and target must be of the same type, either both numpy "
                     f"arrays or both lists of paths, got {type(self.inputs)} (data) "
-                    f"and {type(self.data_target)} (target)."
+                    f"and {type(self.input_targets)} (target)."
                 )
         else:
             if isinstance(self.inputs, np.ndarray):
@@ -137,9 +163,9 @@ class InMemoryDataset(Dataset):
         int
             Length of the dataset.
         """
-        return len(self.data)
+        return len(self.patches)
 
-    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
         """
         Return the patch corresponding to the provided index.
 
@@ -158,12 +184,12 @@ class InMemoryDataset(Dataset):
         ValueError
             If dataset mean and std are not set.
         """
-        patch = self.data[index]
+        patch = self.patches[index]
 
         # if there is a target
-        if self.data_target is not None:
+        if self.patch_targets is not None:
             # get target
-            target = self.data_targets[index]
+            target = self.patch_targets[index]
 
             return self.patch_transform(patch=patch, target=target)
 
@@ -223,25 +249,25 @@ class InMemoryDataset(Dataset):
         indices = np.random.choice(total_patches, n_patches, replace=False)
 
         # extract patches
-        val_patches = self.data[indices]
+        val_patches = self.patches[indices]
 
         # remove patches from self.patch
-        self.data = np.delete(self.data, indices, axis=0)
+        self.patches = np.delete(self.patches, indices, axis=0)
 
         # same for targets
-        if self.data_targets is not None:
-            val_targets = self.data_targets[indices]
-            self.data_targets = np.delete(self.data_targets, indices, axis=0)
+        if self.patch_targets is not None:
+            val_targets = self.patch_targets[indices]
+            self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
 
         # clone the dataset
         dataset = copy.deepcopy(self)
 
         # reassign patches
-        dataset.data = val_patches
+        dataset.patches = val_patches
 
         # reassign targets
-        if self.data_targets is not None:
-            dataset.data_targets = val_targets
+        if self.patch_targets is not None:
+            dataset.patch_targets = val_targets
 
         return dataset
 
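
Note: the renamed attributes do not change the validation-split logic itself, which is plain numpy indexing; a standalone sketch:

    import numpy as np

    patches = np.zeros((100, 1, 64, 64))  # stand-in for dataset.patches
    indices = np.random.choice(len(patches), 10, replace=False)

    val_patches = patches[indices]                 # moved to the validation set
    patches = np.delete(patches, indices, axis=0)  # removed from training
    assert len(patches) == 90 and len(val_patches) == 10
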
@@ -250,7 +276,16 @@ class InMemoryPredictionDataset(Dataset):
     """
     Dataset storing data in memory and allowing generating patches from it.
 
-    # TODO
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : np.ndarray
+        Input data.
+    data_target : Optional[np.ndarray], optional
+        Target data, by default None.
+    read_source_func : Optional[Callable], optional
+        Read source function for custom types, by default read_tiff.
     """
 
     def __init__(
@@ -264,10 +299,14 @@ class InMemoryPredictionDataset(Dataset):
 
         Parameters
         ----------
-        array : np.ndarray
-            Array containing the data.
-        axes : str
-            Description of axes in format STCZYX.
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : np.ndarray
+            Input data.
+        data_target : Optional[np.ndarray], optional
+            Target data, by default None.
+        read_source_func : Optional[Callable], optional
+            Read source function for custom types, by default read_tiff.
 
         Raises
         ------
@@ -295,7 +334,7 @@ class InMemoryPredictionDataset(Dataset):
 
         # get transforms
         self.patch_transform = Compose(
-            transform_list=self.pred_config.transforms,
+            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
         )
 
     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
@@ -310,7 +349,7 @@ class InMemoryPredictionDataset(Dataset):
         # reshape array
         reshaped_sample = reshape_array(self.input_array, self.axes)
 
-        if self.tiling:
+        if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
             # generate patches, which returns a generator
             patch_generator = extract_tiles(
                 arr=reshaped_sample,
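
Note: the added `is not None` checks are runtime no-ops when the `tiling`/`tile` flags already encode that both values are set (shown explicitly for `IterablePredictionDataset` further down); their main effect is to let static type checkers narrow the `Optional` tile parameters before `extract_tiles` is called.
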

careamics/dataset/iterable_dataset.py

@@ -1,3 +1,5 @@
+"""Iterable dataset used to load data file by file."""
+
 from __future__ import annotations
 
 import copy
@@ -11,6 +13,7 @@ from careamics.transforms import Compose
 
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
 from .patching.random_patching import extract_patches_random
@@ -19,13 +22,85 @@ from .patching.tiled_patching import extract_tiles
 logger = get_logger(__name__)
 
 
+def _iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """
+    Iterate over data source and yield whole image.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    np.ndarray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    yield sample, target
+                else:
+                    yield sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
+
+
 class PathIterableDataset(IterableDataset):
     """
     Dataset allowing extracting patches w/o loading whole data into memory.
 
     Parameters
     ----------
-    data_path : Union[str, Path]
+    data_config : DataConfig
+        Data configuration.
+    src_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]], optional
+        Optional list of target files, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+
+    Attributes
+    ----------
+    data_path : List[Path]
         Path to the data, must be a directory.
     axes : str
         Description of axes in format STCZYX.
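
Note: moving `_iterate_over_files` to module level keeps the worker round-robin intact. In isolation the sharding reduces to this sketch (`shard` is a hypothetical helper, not part of the package):

    from torch.utils.data import get_worker_info

    def shard(files):
        info = get_worker_info()  # None in the main process
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        # with num_workers=3, worker 0 reads files 0, 3, 6, ..., worker 1
        # reads 1, 4, 7, ..., so no file is read twice across workers
        return [f for i, f in enumerate(files) if i % num_workers == worker_id]
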
@@ -45,11 +120,24 @@ class PathIterableDataset(IterableDataset):
 
     def __init__(
         self,
-        data_config: Union[DataConfig, InferenceConfig],
+        data_config: DataConfig,
         src_files: List[Path],
         target_files: Optional[List[Path]] = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
+        """Constructors.
+
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        src_files : List[Path]
+            List of data files.
+        target_files : Optional[List[Path]], optional
+            Optional list of target files, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        """
         self.data_config = data_config
         self.data_files = src_files
         self.target_files = target_files
@@ -82,7 +170,9 @@ class PathIterableDataset(IterableDataset):
         means, stds = 0, 0
         num_samples = 0
 
-        for sample, _ in self._iterate_over_files():
+        for sample, _ in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             means += sample.mean()
             stds += sample.std()
             num_samples += 1
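
Note: the statistics pass accumulates one mean and one standard deviation per file and averages them (the division happens past the end of this hunk). A standalone equivalent, exact only when all files have the same size:

    import numpy as np

    samples = [np.random.rand(64, 64) for _ in range(4)]
    result_mean = sum(s.mean() for s in samples) / len(samples)
    result_std = sum(s.std() for s in samples) / len(samples)
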
@@ -97,57 +187,9 @@ class PathIterableDataset(IterableDataset):
         logger.info(f"Mean: {result_mean}, std: {result_std}")
         return result_mean, result_std
 
-    def _iterate_over_files(
-        self,
-    ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
-        """
-        Iterate over data source and yield whole image.
-
-        Yields
-        ------
-        np.ndarray
-            Image.
-        """
-        # When num_workers > 0, each worker process will have a different copy of the
-        # dataset object
-        # Configuring each copy independently to avoid having duplicate data returned
-        # from the workers
-        worker_info = get_worker_info()
-        worker_id = worker_info.id if worker_info is not None else 0
-        num_workers = worker_info.num_workers if worker_info is not None else 1
-
-        # iterate over the files
-        for i, filename in enumerate(self.data_files):
-            # retrieve file corresponding to the worker id
-            if i % num_workers == worker_id:
-                try:
-                    # read data
-                    sample = self.read_source_func(filename, self.data_config.axes)
-
-                    # read target, if available
-                    if self.target_files is not None:
-                        if filename.name != self.target_files[i].name:
-                            raise ValueError(
-                                f"File {filename} does not match target file "
-                                f"{self.target_files[i]}. Have you passed sorted "
-                                f"arrays?"
-                            )
-
-                        # read target
-                        target = self.read_source_func(
-                            self.target_files[i], self.data_config.axes
-                        )
-
-                        yield sample, target
-                    else:
-                        yield sample, None
-
-                except Exception as e:
-                    logger.error(f"Error reading file {filename}: {e}")
-
     def __iter__(
         self,
-    ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]:
+    ) -> Generator[Tuple[np.ndarray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
@@ -161,7 +203,9 @@ class PathIterableDataset(IterableDataset):
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in self._iterate_over_files():
+        for sample_input, sample_target in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             reshaped_sample = reshape_array(sample_input, self.data_config.axes)
             reshaped_target = (
                 None
@@ -209,9 +253,9 @@ class PathIterableDataset(IterableDataset):
         Parameters
         ----------
         percentage : float, optional
-            Percentage of files to split up, by default 0.1
+            Percentage of files to split up, by default 0.1.
         minimum_number : int, optional
-            Minimum number of files to split up, by default 5
+            Minimum number of files to split up, by default 5.
 
         Returns
         -------
@@ -275,12 +319,23 @@ class PathIterableDataset(IterableDataset):
         return dataset
 
 
-class IterablePredictionDataset(PathIterableDataset):
+class IterablePredictionDataset(IterableDataset):
     """
-    Dataset allowing extracting patches w/o loading whole data into memory.
+    Prediction dataset.
 
     Parameters
     ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : List[Path]
+        List of data files.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+
+    Attributes
+    ----------
     data_path : Union[str, Path]
         Path to the data, must be a directory.
     axes : str
@@ -300,13 +355,26 @@ class IterablePredictionDataset(PathIterableDataset):
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
-        super().__init__(
-            data_config=prediction_config,
-            src_files=src_files,
-            read_source_func=read_source_func,
-        )
+        """Constructor.
 
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Inference configuration.
+        src_files : List[Path]
+            List of data files.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the inference configuration.
+        """
         self.prediction_config = prediction_config
+        self.data_files = src_files
         self.axes = prediction_config.axes
         self.tile_size = self.prediction_config.tile_size
         self.tile_overlap = self.prediction_config.tile_overlap
@@ -315,10 +383,21 @@ class IterablePredictionDataset(PathIterableDataset):
         # tile only if both tile size and overlaps are provided
         self.tile = self.tile_size is not None and self.tile_overlap is not None
 
-        # get tta transforms
-        self.patch_transform = Compose(
-            transform_list=prediction_config.transforms,
-        )
+        # check mean and std and create normalize transform
+        if self.prediction_config.mean is None or self.prediction_config.std is None:
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.mean = self.prediction_config.mean
+            self.std = self.prediction_config.std
+
+            # instantiate normalize transform
+            self.patch_transform = Compose(
+                transform_list=[
+                    NormalizeModel(
+                        mean=prediction_config.mean, std=prediction_config.std
+                    )
+                ],
+            )
 
     def __iter__(
         self,
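
Note: both prediction datasets now build their pipeline from a single `NormalizeModel` rather than from the `transforms` field of the configuration. Standalone, the construction shown in the hunk is (mean/std values illustrative):

    from careamics.config.transformations import NormalizeModel
    from careamics.transforms import Compose

    patch_transform = Compose(
        transform_list=[NormalizeModel(mean=128.0, std=32.0)],
    )
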
@@ -335,11 +414,19 @@ class IterablePredictionDataset(PathIterableDataset):
             self.mean is not None and self.std is not None
         ), "Mean and std must be provided"
 
-        for sample, _ in self._iterate_over_files():
+        for sample, _ in _iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
             # reshape array
             reshaped_sample = reshape_array(sample, self.axes)
 
-            if self.tile:
+            if (
+                self.tile
+                and self.tile_size is not None
+                and self.tile_overlap is not None
+            ):
                 # generate patches, return a generator
                 patch_gen = extract_tiles(
                     arr=reshaped_sample,

careamics/dataset/patching/patching.py

@@ -1,8 +1,4 @@
-"""
-Tiling submodule.
-
-These functions are used to tile images into patches or tiles.
-"""
+"""Patching functions."""
 
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
@@ -21,12 +17,25 @@ def prepare_patches_supervised(
     train_files: List[Path],
     target_files: List[Path],
     axes: str,
-    patch_size: Union[List[int], Tuple[int]],
+    patch_size: Union[List[int], Tuple[int, ...]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, np.ndarray, float, float]:
     """
     Iterate over data source and create an array of patches and corresponding targets.
 
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    target_files : List[Path]
+        List of paths to target data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
     Returns
     -------
     np.ndarray
@@ -95,13 +104,25 @@ def prepare_patches_unsupervised(
     patch_size: Union[List[int], Tuple[int]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, None, float, float]:
-    """
-    Iterate over data source and create an array of patches.
+    """Iterate over data source and create an array of patches.
+
+    This method returns the mean and standard deviation of the image.
+
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
 
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source and target patches, mean and standard deviation.
     """
     means, stds, num_samples = 0, 0, 0
     all_patches = []
@@ -150,10 +171,21 @@ def prepare_patches_supervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    data_target : np.ndarray
+        Target data array.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Source and target patches, mean and standard deviation.
     """
     # compute statistics
     mean = data.mean()
@@ -195,10 +227,19 @@ def prepare_patches_unsupervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source patches, mean and standard deviation.
     """
     # calculate mean and std
     mean = data.mean()
@@ -210,4 +251,4 @@ def prepare_patches_unsupervised_array(
     # generate patches, return a generator
     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
 
-    return patches, _, mean, std # TODO inelegant, replace by dataclass?
+    return patches, _, mean, std  # TODO inelegant, replace by dataclass?
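
Note: the corrected Returns sections make the four-element contract of the prepare_patches_* functions explicit. A hedged sketch of the unpacking (argument order taken from the docstring above, shapes illustrative):

    import numpy as np
    from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

    data = np.random.rand(256, 256)
    patches, _, mean, std = prepare_patches_unsupervised_array(data, "YX", (64, 64))
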

careamics/dataset/patching/random_patching.py

@@ -1,3 +1,5 @@
+"""Random patching utilities."""
+
 from typing import Generator, List, Optional, Tuple, Union
 
 import numpy as np
@@ -30,6 +32,8 @@ def extract_patches_random(
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Yields
     ------
@@ -120,10 +124,12 @@ def extract_patches_random_from_chunks(
     ----------
     arr : np.ndarray
         Input image array.
-    patch_size : Tuple[int]
+    patch_size : Union[List[int], Tuple[int, ...]]
         Patch sizes in each dimension.
-    chunk_size : Tuple[int]
+    chunk_size : Union[List[int], Tuple[int, ...]]
         Chunk sizes to load from the.
+    chunk_limit : Optional[int], optional
+        Number of chunks to load, by default None.
 
     Yields
     ------
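
Note: `extract_patches_random` is a generator; with the newly documented `target` parameter it yields paired source/target patches. The first argument name and the yield structure are inferred from the docstrings, not shown in these hunks:

    import numpy as np
    from careamics.dataset.patching.random_patching import extract_patches_random

    arr = np.random.rand(1, 1, 128, 128)  # illustrative SCYX-shaped array
    for patch, target_patch in extract_patches_random(
        arr, patch_size=(64, 64), target=arr
    ):
        ...  # patches are drawn at random locations; source/target stay aligned
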