careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/iterable_dataset.py
@@ -1,3 +1,5 @@
+"""Iterable dataset used to load data file by file."""
+
 from __future__ import annotations
 
 import copy
@@ -7,26 +9,98 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 import numpy as np
 from torch.utils.data import IterableDataset, get_worker_info
 
+from careamics.transforms import Compose
+
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
-from .patching import (
-    get_patch_transform,
-)
 from .patching.random_patching import extract_patches_random
 from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
 
 
+def _iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """
+    Iterate over data source and yield whole image.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    np.ndarray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    yield sample, target
+                else:
+                    yield sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
+
+
 class PathIterableDataset(IterableDataset):
     """
     Dataset allowing extracting patches w/o loading whole data into memory.
 
     Parameters
     ----------
-    data_path : Union[str, Path]
+    data_config : DataConfig
+        Data configuration.
+    src_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]], optional
+        Optional list of target files, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+
+    Attributes
+    ----------
+    data_path : List[Path]
         Path to the data, must be a directory.
     axes : str
         Description of axes in format STCZYX.
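Note on the hunk above: _iterate_over_files is now a module-level generator, and file-to-worker assignment is a plain round-robin on the file index, so each DataLoader worker reads a disjoint subset. A standalone sketch of that rule (function and file names here are illustrative, not part of careamics):

from typing import List

def shard_files(files: List[str], worker_id: int, num_workers: int) -> List[str]:
    # same rule as `if i % num_workers == worker_id` in _iterate_over_files
    return [f for i, f in enumerate(files) if i % num_workers == worker_id]

files = [f"img_{i}.tif" for i in range(5)]
assert shard_files(files, 0, 2) == ["img_0.tif", "img_2.tif", "img_4.tif"]
assert shard_files(files, 1, 2) == ["img_1.tif", "img_3.tif"]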
@@ -46,11 +120,24 @@ class PathIterableDataset(IterableDataset):
 
     def __init__(
         self,
-        data_config: Union[DataConfig, InferenceConfig],
+        data_config: DataConfig,
         src_files: List[Path],
         target_files: Optional[List[Path]] = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        src_files : List[Path]
+            List of data files.
+        target_files : Optional[List[Path]], optional
+            Optional list of target files, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        """
         self.data_config = data_config
         self.data_files = src_files
         self.target_files = target_files
@@ -61,26 +148,15 @@ class PathIterableDataset(IterableDataset):
         if not data_config.mean or not data_config.std:
             self.mean, self.std = self._calculate_mean_and_std()
 
-            # if the transforms are not an instance of Compose
-            # Check if the data_config is an instance of DataModel or InferenceModel
-            # isinstance isn't working properly here
-            if hasattr(data_config, "has_transform_list"):
-                if data_config.has_transform_list():
-                    # update mean and std in configuration
-                    # the object is mutable and should then be recorded in the CAREamist
-                    data_config.set_mean_and_std(self.mean, self.std)
-            else:
-                data_config.set_mean_and_std(self.mean, self.std)
-
+            # update mean and std in configuration
+            # the object is mutable and should then be recorded in the CAREamist
+            data_config.set_mean_and_std(self.mean, self.std)
         else:
             self.mean = data_config.mean
             self.std = data_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=data_config.transforms,
-            with_target=target_files is not None,
-        )
+        self.patch_transform = Compose(transform_list=data_config.transforms)
 
     def _calculate_mean_and_std(self) -> Tuple[float, float]:
         """
@@ -94,7 +170,9 @@
         means, stds = 0, 0
         num_samples = 0
 
-        for sample, _ in self._iterate_over_files():
+        for sample, _ in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             means += sample.mean()
             stds += sample.std()
             num_samples += 1
@@ -109,57 +187,9 @@
         logger.info(f"Mean: {result_mean}, std: {result_std}")
         return result_mean, result_std
 
-    def _iterate_over_files(
-        self,
-    ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
-        """
-        Iterate over data source and yield whole image.
-
-        Yields
-        ------
-        np.ndarray
-            Image.
-        """
-        # When num_workers > 0, each worker process will have a different copy of the
-        # dataset object
-        # Configuring each copy independently to avoid having duplicate data returned
-        # from the workers
-        worker_info = get_worker_info()
-        worker_id = worker_info.id if worker_info is not None else 0
-        num_workers = worker_info.num_workers if worker_info is not None else 1
-
-        # iterate over the files
-        for i, filename in enumerate(self.data_files):
-            # retrieve file corresponding to the worker id
-            if i % num_workers == worker_id:
-                try:
-                    # read data
-                    sample = self.read_source_func(filename, self.data_config.axes)
-
-                    # read target, if available
-                    if self.target_files is not None:
-                        if filename.name != self.target_files[i].name:
-                            raise ValueError(
-                                f"File {filename} does not match target file "
-                                f"{self.target_files[i]}. Have you passed sorted "
-                                f"arrays?"
-                            )
-
-                        # read target
-                        target = self.read_source_func(
-                            self.target_files[i], self.data_config.axes
-                        )
-
-                        yield sample, target
-                    else:
-                        yield sample, None
-
-                except Exception as e:
-                    logger.error(f"Error reading file {filename}: {e}")
-
     def __iter__(
         self,
-    ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]:
+    ) -> Generator[Tuple[np.ndarray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
@@ -173,7 +203,9 @@
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in self._iterate_over_files():
+        for sample_input, sample_target in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             reshaped_sample = reshape_array(sample_input, self.data_config.axes)
             reshaped_target = (
                 None
@@ -192,49 +224,10 @@
             # or (patch, None) only if no target is available
             # patch is of dimensions (C)ZYX
             for patch_data in patches:
-                # if there is a target
-                if self.target_files is not None:
-                    # Albumentations expects the channel dimension to be last
-                    # Taking the first element because patch_data can include target
-                    c_patch = np.moveaxis(patch_data[0], 0, -1)
-                    c_target = np.moveaxis(patch_data[1], 0, -1)
-
-                    # apply the transform to the patch and the target
-                    transformed = self.patch_transform(
-                        image=c_patch,
-                        target=c_target,
-                    )
-
-                    # move the axes back to the original position
-                    c_patch = np.moveaxis(transformed["image"], -1, 0)
-                    c_target = np.moveaxis(transformed["target"], -1, 0)
-
-                    yield (c_patch, c_target)
-                elif self.data_config.has_n2v_manipulate():
-                    # Albumentations expects the channel dimension to be last
-                    # Taking the first element because patch_data can include target
-                    patch = np.moveaxis(patch_data[0], 0, -1)
-
-                    # apply transform
-                    transformed = self.patch_transform(image=patch)
-
-                    # retrieve the output of ManipulateN2V
-                    results = transformed["image"]
-                    masked_patch: np.ndarray = results[0]
-                    original_patch: np.ndarray = results[1]
-                    mask: np.ndarray = results[2]
-
-                    # move C axes back
-                    masked_patch = np.moveaxis(masked_patch, -1, 0)
-                    original_patch = np.moveaxis(original_patch, -1, 0)
-                    mask = np.moveaxis(mask, -1, 0)
-
-                    yield (masked_patch, original_patch, mask)
-                else:
-                    raise ValueError(
-                        "Something went wrong! Not target file (no supervised "
-                        "training) and no N2V transform (no n2v training either)."
-                    )
+                yield self.patch_transform(
+                    patch=patch_data[0],
+                    target=patch_data[1],
+                )
 
     def get_number_of_files(self) -> int:
         """
@@ -260,9 +253,9 @@
         Parameters
         ----------
         percentage : float, optional
-            Percentage of files to split up, by default 0.1
+            Percentage of files to split up, by default 0.1.
         minimum_number : int, optional
-            Minimum number of files to split up, by default 5
+            Minimum number of files to split up, by default 5.
 
         Returns
         -------
@@ -326,12 +319,23 @@
         return dataset
 
 
-class IterablePredictionDataset(PathIterableDataset):
+class IterablePredictionDataset(IterableDataset):
     """
-    Dataset allowing extracting patches w/o loading whole data into memory.
+    Prediction dataset.
 
     Parameters
     ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : List[Path]
+        List of data files.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+
+    Attributes
+    ----------
     data_path : Union[str, Path]
         Path to the data, must be a directory.
     axes : str
@@ -351,13 +355,26 @@ class IterablePredictionDataset(PathIterableDataset):
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
-        super().__init__(
-            data_config=prediction_config,
-            src_files=src_files,
-            read_source_func=read_source_func,
-        )
+        """Constructor.
 
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Inference configuration.
+        src_files : List[Path]
+            List of data files.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the inference configuration.
+        """
         self.prediction_config = prediction_config
+        self.data_files = src_files
         self.axes = prediction_config.axes
         self.tile_size = self.prediction_config.tile_size
         self.tile_overlap = self.prediction_config.tile_overlap
@@ -366,11 +383,21 @@ class IterablePredictionDataset(PathIterableDataset):
         # tile only if both tile size and overlaps are provided
         self.tile = self.tile_size is not None and self.tile_overlap is not None
 
-        # get tta transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=prediction_config.transforms,
-            with_target=False,
-        )
+        # check mean and std and create normalize transform
+        if self.prediction_config.mean is None or self.prediction_config.std is None:
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.mean = self.prediction_config.mean
+            self.std = self.prediction_config.std
+
+        # instantiate normalize transform
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(
+                    mean=prediction_config.mean, std=prediction_config.std
+                )
+            ],
+        )
 
     def __iter__(
         self,
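Note: prediction no longer reads prediction_config.transforms at all; it builds a single NormalizeModel pipeline from the stored statistics. Numerically this amounts to the usual standardization, sketched below with made-up statistics (the real transform lives in careamics/transforms/normalize.py):

import numpy as np

# illustrative values; in the dataset they come from
# prediction_config.mean and prediction_config.std
mean, std = 128.0, 32.0

tile = np.random.rand(1, 64, 64).astype(np.float32)
normalized = (tile - mean) / std  # what NormalizeModel boils down to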
@@ -387,11 +414,19 @@ class IterablePredictionDataset(PathIterableDataset):
             self.mean is not None and self.std is not None
         ), "Mean and std must be provided"
 
-        for sample, _ in self._iterate_over_files():
+        for sample, _ in _iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
             # reshape array
             reshaped_sample = reshape_array(sample, self.axes)
 
-            if self.tile:
+            if (
+                self.tile
+                and self.tile_size is not None
+                and self.tile_overlap is not None
+            ):
                 # generate patches, return a generator
                 patch_gen = extract_tiles(
                     arr=reshaped_sample,
@@ -408,9 +443,6 @@
 
             # apply transform to patches
             for patch_array, tile_info in patch_gen:
-                # albumentations expects the channel dimension to be last
-                patch = np.moveaxis(patch_array, 0, -1)
-                transformed_patch = self.patch_transform(image=patch)
-                transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
+                transformed_patch, _ = self.patch_transform(patch=patch_array)
 
                 yield transformed_patch, tile_info
careamics/dataset/patching/__init__.py
@@ -1,8 +1 @@
 """Patching and tiling functions."""
-
-
-__all__ = [
-    "get_patch_transform",
-]
-
-from .patch_transform import get_patch_transform
careamics/dataset/patching/patching.py
@@ -1,8 +1,5 @@
-"""
-Tiling submodule.
+"""Patching functions."""
 
-These functions are used to tile images into patches or tiles.
-"""
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
 
@@ -20,12 +17,25 @@ def prepare_patches_supervised(
     train_files: List[Path],
     target_files: List[Path],
     axes: str,
-    patch_size: Union[List[int], Tuple[int]],
+    patch_size: Union[List[int], Tuple[int, ...]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, np.ndarray, float, float]:
     """
     Iterate over data source and create an array of patches and corresponding targets.
 
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    target_files : List[Path]
+        List of paths to target data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
     Returns
     -------
     np.ndarray
@@ -94,13 +104,25 @@ def prepare_patches_unsupervised(
     patch_size: Union[List[int], Tuple[int]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, None, float, float]:
-    """
-    Iterate over data source and create an array of patches.
+    """Iterate over data source and create an array of patches.
+
+    This method returns the mean and standard deviation of the image.
+
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
 
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source and target patches, mean and standard deviation.
     """
     means, stds, num_samples = 0, 0, 0
     all_patches = []
@@ -149,10 +171,21 @@ def prepare_patches_supervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    data_target : np.ndarray
+        Target data array.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Source and target patches, mean and standard deviation.
     """
     # compute statistics
     mean = data.mean()
@@ -194,10 +227,19 @@ def prepare_patches_unsupervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source patches, mean and standard deviation.
     """
     # calculate mean and std
     mean = data.mean()
@@ -209,4 +251,4 @@ def prepare_patches_unsupervised_array(
 
     # generate patches, return a generator
     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
-    return patches, _, mean, std # TODO inelegant, replace by dataclass?
+    return patches, _, mean, std  # TODO inelegant, replace by dataclass?
careamics/dataset/patching/random_patching.py
@@ -1,3 +1,5 @@
+"""Random patching utilities."""
+
 from typing import Generator, List, Optional, Tuple, Union
 
 import numpy as np
@@ -30,6 +32,8 @@ def extract_patches_random(
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Yields
     ------
@@ -120,10 +124,12 @@ def extract_patches_random_from_chunks(
     ----------
     arr : np.ndarray
         Input image array.
-    patch_size : Tuple[int]
+    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each dimension.
-    chunk_size : Tuple[int]
+    chunk_size : Union[List[int], Tuple[int, ...]]
         Chunk sizes to load from the.
+    chunk_limit : Optional[int], optional
+        Number of chunks to load, by default None.
 
     Yields
     ------
careamics/dataset/patching/sequential_patching.py
@@ -1,3 +1,5 @@
+"""Sequential patching functions."""
+
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -14,14 +16,14 @@ def _compute_number_of_patches(
 
     Parameters
     ----------
-    arr : Tuple[int, ...]
+    arr_shape : Tuple[int, ...]
         Shape of the input array.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]]
         Shape of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Number of patches in each dimension.
     """
     if len(arr_shape) != len(patch_sizes):
@@ -55,14 +57,14 @@ def _compute_overlap(
 
     Parameters
     ----------
-    arr : Tuple[int, ...]
+    arr_shape : Tuple[int, ...]
         Input array shape.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]]
         Size of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Overlap between patches in each dimension.
     """
     n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -123,6 +125,8 @@ def _compute_patch_views(
         Steps between views.
     output_shape : Tuple[int]
         Shape of the output array.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------
135
139
  arr = np.stack([arr, target], axis=0)
136
140
  window_shape = [arr.shape[0], *window_shape]
137
141
  step = (arr.shape[0], *step)
138
- output_shape = [arr.shape[0], -1, arr.shape[2], *output_shape[2:]]
142
+ output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
139
143
 
140
144
  patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
141
145
  *output_shape
142
146
  )
143
- if target is not None:
144
- rng.shuffle(patches, axis=1)
145
- else:
146
- rng.shuffle(patches, axis=0)
147
+ rng.shuffle(patches, axis=0)
147
148
  return patches
148
149
 
149
150
 
@@ -164,11 +165,13 @@ def extract_patches_sequential(
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------
-    Generator[Tuple[np.ndarray, ...], None, None]
-        Generator of patches.
+    Tuple[np.ndarray, Optional[np.ndarray]]
+        Patches.
     """
     is_3d_patch = len(patch_size) == 3
 
@@ -201,6 +204,9 @@
 
     if target is not None:
         # target was concatenated to patches in _compute_reshaped_view
-        return (patches[0, ...], patches[1, ...])  # TODO in _compute_reshaped_view?
+        return (
+            patches[:, 0, ...],
+            patches[:, 1, ...],
+        )  # TODO in _compute_reshaped_view?
     else:
         return patches, None
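Note on the two hunks above: with source and target stacked on an extra axis, the reshape now puts the patch axis S first and the source/target pair axis second, so a single rng.shuffle(patches, axis=0) permutes patches while keeping each pair together, and the unpacking becomes patches[:, 0] / patches[:, 1]. A small numpy sketch of the invariant (data is made up):

import numpy as np

rng = np.random.default_rng(42)

# four paired patches laid out as (S, pair, ...): axis 0 indexes patches,
# axis 1 holds the (source, target) pair, like the new output_shape
source = np.arange(4)[:, None] * np.ones((4, 3))
target = source + 0.5  # each target must stay with its source
patches = np.stack([source, target], axis=1)  # shape (4, 2, 3)

rng.shuffle(patches, axis=0)  # permutes patch order only

# the pairing survives the shuffle:
assert np.allclose(patches[:, 1, ...], patches[:, 0, ...] + 0.5)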