careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of careamics might be problematic.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/patching/patching.py

@@ -1,5 +1,6 @@
 """Patching functions."""
 
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
 
@@ -7,11 +8,38 @@ import numpy as np
 
 from ...utils.logging import get_logger
 from ..dataset_utils import reshape_array
+from ..dataset_utils.running_stats import compute_normalization_stats
 from .sequential_patching import extract_patches_sequential
 
 logger = get_logger(__name__)
 
 
+@dataclass
+class Stats:
+    """Dataclass to store statistics."""
+
+    means: Union[np.ndarray, tuple, list, None]
+    stds: Union[np.ndarray, tuple, list, None]
+
+
+@dataclass
+class PatchedOutput:
+    """Dataclass to store patches and statistics."""
+
+    patches: Union[np.ndarray]
+    targets: Union[np.ndarray, None]
+    image_stats: Stats
+    target_stats: Stats
+
+
+@dataclass
+class StatsOutput:
+    """Dataclass to store patches and statistics."""
+
+    image_stats: Stats
+    target_stats: Stats
+
+
 # called by in memory dataset
 def prepare_patches_supervised(
     train_files: List[Path],
@@ -19,10 +47,12 @@ def prepare_patches_supervised(
     axes: str,
     patch_size: Union[List[int], Tuple[int, ...]],
     read_source_func: Callable,
-) -> Tuple[np.ndarray, np.ndarray, float, float]:
+) -> PatchedOutput:
     """
    Iterate over data source and create an array of patches and corresponding targets.
 
+    The lists of Paths should be pre-sorted.
+
    Parameters
    ----------
    train_files : List[Path]
@@ -41,9 +71,6 @@ def prepare_patches_supervised(
    np.ndarray
        Array of patches.
    """
-    train_files.sort()
-    target_files.sort()
-
     means, stds, num_samples = 0, 0, 0
     all_patches, all_targets = [], []
     for train_filename, target_filename in zip(train_files, target_files):
@@ -83,17 +110,18 @@ def prepare_patches_supervised(
            f"{target_files}."
        )
 
-    result_mean, result_std = means / num_samples, stds / num_samples
+    image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
+    target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
 
     patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
     target_array: np.ndarray = np.concatenate(all_targets, axis=0)
     logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
 
-    return (
+    return PatchedOutput(
         patch_array,
         target_array,
-        result_mean,
-        result_std,
+        Stats(image_means, image_stds),
+        Stats(target_means, target_stds),
     )
 
 
@@ -103,7 +131,7 @@ def prepare_patches_unsupervised(
     axes: str,
     patch_size: Union[List[int], Tuple[int]],
     read_source_func: Callable,
-) -> Tuple[np.ndarray, None, float, float]:
+) -> PatchedOutput:
     """Iterate over data source and create an array of patches.
 
     This method returns the mean and standard deviation of the image.
@@ -149,12 +177,14 @@ def prepare_patches_unsupervised(
     if num_samples == 0:
         raise ValueError(f"No valid samples found in the input data: {train_files}.")
 
-    result_mean, result_std = means / num_samples, stds / num_samples
+    image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
 
     patch_array: np.ndarray = np.concatenate(all_patches)
     logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
 
-    return patch_array, _, result_mean, result_std  # TODO return object?
+    return PatchedOutput(
+        patch_array, None, Stats(image_means, image_stds), Stats((), ())
+    )
 
 
 # called on arrays by in memory dataset
@@ -163,7 +193,7 @@ def prepare_patches_supervised_array(
     axes: str,
     data_target: np.ndarray,
     patch_size: Union[List[int], Tuple[int]],
-) -> Tuple[np.ndarray, np.ndarray, float, float]:
+) -> PatchedOutput:
     """Iterate over data source and create an array of patches.
 
     This method expects an array of shape SC(Z)YX, where S and C can be singleton
@@ -187,14 +217,14 @@ def prepare_patches_supervised_array(
    Tuple[np.ndarray, np.ndarray, float, float]
        Source and target patches, mean and standard deviation.
    """
-    # compute statistics
-    mean = data.mean()
-    std = data.std()
-
     # reshape array
     reshaped_sample = reshape_array(data, axes)
     reshaped_target = reshape_array(data_target, axes)
 
+    # compute statistics
+    image_means, image_stds = compute_normalization_stats(reshaped_sample)
+    target_means, target_stds = compute_normalization_stats(reshaped_target)
+
     # generate patches, return a generator
     patches, patch_targets = extract_patches_sequential(
         reshaped_sample, patch_size=patch_size, target=reshaped_target
@@ -205,11 +235,11 @@ def prepare_patches_supervised_array(
 
     logger.info(f"Extracted {patches.shape[0]} patches from input array.")
 
-    return (
+    return PatchedOutput(
         patches,
         patch_targets,
-        mean,
-        std,
+        Stats(image_means, image_stds),
+        Stats(target_means, target_stds),
     )
 
 
@@ -218,7 +248,7 @@ def prepare_patches_unsupervised_array(
     data: np.ndarray,
     axes: str,
     patch_size: Union[List[int], Tuple[int]],
-) -> Tuple[np.ndarray, None, float, float]:
+) -> PatchedOutput:
     """
     Iterate over data source and create an array of patches.
 
@@ -241,14 +271,13 @@ def prepare_patches_unsupervised_array(
    Tuple[np.ndarray, None, float, float]
        Source patches, mean and standard deviation.
    """
-    # calculate mean and std
-    mean = data.mean()
-    std = data.std()
-
     # reshape array
     reshaped_sample = reshape_array(data, axes)
 
+    # calculate mean and std
+    means, stds = compute_normalization_stats(reshaped_sample)
+
     # generate patches, return a generator
     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
 
-    return patches, _, mean, std  # TODO inelegant, replace by dataclass?
+    return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
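
The patching functions now return a PatchedOutput dataclass rather than a loose tuple. A minimal sketch of consuming the new return type, with an assumed toy array (shapes and values are illustrative, not from the diff):

import numpy as np

from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

data = np.random.rand(256, 256).astype(np.float32)  # toy 2D image, YX axes

output = prepare_patches_unsupervised_array(data, axes="YX", patch_size=(64, 64))

print(output.patches.shape)      # patches stacked along the first axis
print(output.targets)            # None for unsupervised data
print(output.image_stats.means)  # means from compute_normalization_stats
print(output.image_stats.stds)   # matching standard deviations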
careamics/dataset/patching/random_patching.py

@@ -13,6 +13,7 @@ def extract_patches_random(
     arr: np.ndarray,
     patch_size: Union[List[int], Tuple[int, ...]],
     target: Optional[np.ndarray] = None,
+    seed: Optional[int] = None,
 ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
     """
     Generate patches from an array in a random manner.
@@ -34,12 +35,16 @@ def extract_patches_random(
        Patch sizes in each dimension.
    target : Optional[np.ndarray], optional
        Target array, by default None.
+    seed : Optional[int], optional
+        Random seed, by default None.
 
    Yields
    ------
    Generator[np.ndarray, None, None]
        Generator of patches.
    """
+    rng = np.random.default_rng(seed=seed)
+
     is_3d_patch = len(patch_size) == 3
 
     # patches sanity check
@@ -48,9 +53,6 @@ def extract_patches_random(
 
     # Update patch size to encompass S and C dimensions
     patch_size = [1, arr.shape[1], *patch_size]
-    # random generator
-    rng = np.random.default_rng()
-
     # iterate over the number of samples (S or T)
     for sample_idx in range(arr.shape[0]):
         # get sample array
@@ -113,6 +115,7 @@ def extract_patches_random_from_chunks(
     patch_size: Union[List[int], Tuple[int, ...]],
     chunk_size: Union[List[int], Tuple[int, ...]],
     chunk_limit: Optional[int] = None,
+    seed: Optional[int] = None,
 ) -> Generator[np.ndarray, None, None]:
     """
     Generate patches from an array in a random manner.
@@ -130,6 +133,8 @@ def extract_patches_random_from_chunks(
        Chunk sizes to load from the.
    chunk_limit : Optional[int], optional
        Number of chunks to load, by default None.
+    seed : Optional[int], optional
+        Random seed, by default None.
 
    Yields
    ------
@@ -141,7 +146,7 @@ def extract_patches_random_from_chunks(
     # Patches sanity check
     validate_patch_dimensions(arr, patch_size, is_3d_patch)
 
-    rng = np.random.default_rng()
+    rng = np.random.default_rng(seed=seed)
     num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
 
     # Iterate over num chunks in the array
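
Both random patching functions now take an optional seed, making patch extraction reproducible. A small sketch, assuming an SCYX input array:

import numpy as np

from careamics.dataset.patching.random_patching import extract_patches_random

arr = np.random.rand(1, 1, 128, 128)  # SCYX

# the same seed yields the same sequence of random patches
patches_a = [p for p, _ in extract_patches_random(arr, patch_size=(32, 32), seed=42)]
patches_b = [p for p, _ in extract_patches_random(arr, patch_size=(32, 32), seed=42)]

assert all(np.array_equal(a, b) for a, b in zip(patches_a, patches_b))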
careamics/dataset/patching/validate_patch_dimension.py

@@ -45,18 +45,20 @@ def validate_patch_dimensions(
     if len(patch_size) != len(arr.shape[2:]):
         raise ValueError(
             f"There must be a patch size for each spatial dimensions "
-            f"(got {patch_size} patches for dims {arr.shape})."
+            f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
         )
 
     # Sanity checks on patch sizes versus array dimension
     if is_3d_patch and patch_size[0] > arr.shape[-3]:
         raise ValueError(
             f"Z patch size is inconsistent with image shape "
-            f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
+            f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
+            f"order."
         )
 
     if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
         raise ValueError(
             f"At least one of YX patch dimensions is larger than the corresponding "
-            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
+            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
+            f"Check the axes order."
         )
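
The reworded messages point at the most common cause of a dimension mismatch, a wrong axes order. A sketch of triggering the first check (array and patch size are hypothetical):

import numpy as np

from careamics.dataset.patching.validate_patch_dimension import (
    validate_patch_dimensions,
)

arr = np.zeros((1, 1, 128, 128))  # SCYX: two spatial dimensions

try:
    validate_patch_dimensions(arr, [8, 64, 64], True)  # 3D patch on a 2D image
except ValueError as e:
    print(e)  # "There must be a patch size ... Check the axes order."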
careamics/dataset/tiling/__init__.py

@@ -0,0 +1,10 @@
+"""Tiling functions."""
+
+__all__ = [
+    "stitch_prediction",
+    "extract_tiles",
+    "collate_tiles",
+]
+
+from .collate_tiles import collate_tiles
+from .tiled_patching import extract_tiles
careamics/dataset/tiling/collate_tiles.py

@@ -0,0 +1,33 @@
+"""Collate function for tiling."""
+
+from typing import Any, List, Tuple
+
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from careamics.config.tile_information import TileInformation
+
+
+def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+    """
+    Collate tiles received from CAREamics prediction dataloader.
+
+    CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
+    case of non-tiled data, this function will return the arrays. In case of tiled data,
+    it will return the arrays, the last tile flag, the overlap crop coordinates and the
+    stitch coordinates.
+
+    Parameters
+    ----------
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
+        Batch of tiles.
+
+    Returns
+    -------
+    Any
+        Collated batch.
+    """
+    new_batch = [tile for tile, _ in batch]
+    tiles_batch = [tile_info for _, tile_info in batch]
+
+    return default_collate(new_batch), tiles_batch
careamics/dataset/{patching → tiling}/tiled_patching.py

@@ -84,15 +84,15 @@ def extract_tiles(
     tile_size: Union[List[int], Tuple[int, ...]],
     overlaps: Union[List[int], Tuple[int, ...]],
 ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
-    """
-    Generate tiles from the input array with specified overlap.
+    """Generate tiles from the input array with specified overlap.
 
     The tiles cover the whole array. The method returns a generator that yields
     tuples of array and tile information, the latter includes whether
     the tile is the last one, the coordinates of the overlap crop, and the coordinates
     of the stitched tile.
 
-    The array has shape C(Z)YX, where C can be a singleton.
+    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
+    where C can be a singleton.
 
     Parameters
     ----------
@@ -155,10 +155,10 @@ def extract_tiles(
            # create tile information
            tile_info = TileInformation(
                array_shape=sample.squeeze().shape,
-                tiled=True,
                last_tile=last_tile,
                overlap_crop_coords=overlap_crop_coords,
                stitch_coords=stitch_coords,
+                sample_id=sample_idx,
            )
 
            yield tile, tile_info
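
extract_tiles and collate_tiles together form the tiled prediction path: tiles are generated per sample, tagged with the new sample_id, and collated into a batch plus a list of TileInformation. A minimal sketch with an assumed SCYX array:

import numpy as np

from careamics.dataset.tiling import collate_tiles, extract_tiles

arr = np.random.rand(1, 1, 128, 128)  # SCYX

# generator of (tile, TileInformation) pairs; tiles have shape CYX
tiles = list(extract_tiles(arr, tile_size=(64, 64), overlaps=(32, 32)))

batch, tile_infos = collate_tiles(tiles[:4])
print(batch.shape)              # 4 tiles collated into one tensor
print(tile_infos[0].sample_id)  # 0, since all tiles come from the first sample
print(tile_infos[0].stitch_coords)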
careamics/lightning_datamodule.py

@@ -583,12 +583,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
     ...     {
-    ...         "name": SupportedTransform.NORMALIZE.value,
-    ...         "mean": 0,
-    ...         "std": 1,
-    ...     },
-    ...     {
-    ...         "name": SupportedTransform.N2V_MANIPULATE.value,
+    ...         "name": SupportedTransform.XY_FLIP.value,
     ...     }
     ... ]
     >>> data_module = TrainingDataWrapper(
careamics/lightning_module.py

@@ -148,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
        Any
            Model output.
        """
-        x, *aux = batch
+        if self._trainer.datamodule.tiled:
+            x, *aux = batch
+        else:
+            x = batch
+            aux = []
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
         if self._trainer.datamodule.prediction_config.tta_transforms:
             tta = ImageRestorationTTA()
-            augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+            augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
             for augmented in augmented_batch:
                 augmented_pred = self.model(augmented)
@@ -165,13 +169,13 @@ class CAREamicsModule(L.LightningModule):
 
         # Denormalize the output
         denorm = Denormalize(
-            mean=self._trainer.datamodule.predict_dataset.mean,
-            std=self._trainer.datamodule.predict_dataset.std,
+            image_means=self._trainer.datamodule.predict_dataset.image_means,
+            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
         )
-        denormalized_output, _ = denorm(patch=output)
+        denormalized_output = denorm(patch=output.cpu().numpy())
 
-        if len(aux) > 0:
-            return denormalized_output, aux
+        if len(aux) > 0:  # aux can be tiling information
+            return denormalized_output, *aux
         else:
             return denormalized_output
 
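Because collate_tiles returns a (tensor, tile_information_list) pair while untiled batches are plain tensors, predict_step now branches on the data module's new tiled flag. A schematic of the unpacking contract, not library code:

def unpack_prediction_batch(batch, tiled: bool):
    """Schematic of the new branching in predict_step."""
    if tiled:
        # tiled batches come from collate_tiles: (tensor, [TileInformation, ...])
        x, *aux = batch
    else:
        # untiled batches are plain tensors from the default collate function
        x, aux = batch, []
    return x, aux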
careamics/lightning_prediction_datamodule.py

@@ -1,68 +1,37 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L
 from torch.utils.data import DataLoader
-from torch.utils.data.dataloader import default_collate
 
 from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
-from careamics.config.tile_information import TileInformation
+from careamics.dataset import (
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+)
 from careamics.dataset.dataset_utils import (
     get_read_func,
     list_files,
 )
-from careamics.dataset.in_memory_dataset import (
-    InMemoryPredictionDataset,
-)
-from careamics.dataset.iterable_dataset import (
-    IterablePredictionDataset,
-)
+from careamics.dataset.tiling.collate_tiles import collate_tiles
 from careamics.utils import get_logger
 
-PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+PredictDatasetType = Union[
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+]
 
 logger = get_logger(__name__)
 
 
-def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
-    """
-    Collate tiles received from CAREamics prediction dataloader.
-
-    CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
-    case of non-tiled data, this function will return the arrays. In case of tiled data,
-    it will return the arrays, the last tile flag, the overlap crop coordinates and the
-    stitch coordinates.
-
-    Parameters
-    ----------
-    batch : List[Tuple[np.ndarray, TileInformation], ...]
-        Batch of tiles.
-
-    Returns
-    -------
-    Any
-        Collated batch.
-    """
-    first_tile_info: TileInformation = batch[0][1]
-    # if not tiled, then return arrays
-    if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
-    else:
-        new_batch = [
-            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
-            for tile, t in batch
-        ]
-
-        return default_collate(new_batch)
-
-
 class CAREamicsPredictData(L.LightningDataModule):
     """
     CAREamics Lightning prediction data module.
@@ -182,6 +151,9 @@ class CAREamicsPredictData(L.LightningDataModule):
         self.tile_size = pred_config.tile_size
         self.tile_overlap = pred_config.tile_overlap
 
+        # check if it is tiled
+        self.tiled = self.tile_size is not None and self.tile_overlap is not None
+
         # read source function
         if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
@@ -212,17 +184,29 @@ class CAREamicsPredictData(L.LightningDataModule):
        """
        # if numpy array
        if self.data_type == SupportedData.ARRAY:
-            # prediction dataset
-            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
-                prediction_config=self.prediction_config,
-                inputs=self.pred_data,
-            )
+            if self.tiled:
+                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
+            else:
+                self.predict_dataset = InMemoryPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
        else:
-            self.predict_dataset = IterablePredictionDataset(
-                prediction_config=self.prediction_config,
-                src_files=self.pred_files,
-                read_source_func=self.read_source_func,
-            )
+            if self.tiled:
+                self.predict_dataset = IterableTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
+            else:
+                self.predict_dataset = IterablePredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
 
     def predict_dataloader(self) -> DataLoader:
        """
@@ -236,7 +220,7 @@ class CAREamicsPredictData(L.LightningDataModule):
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            collate_fn=_collate_tiles,
+            collate_fn=collate_tiles if self.tiled else None,
             **self.dataloader_params,
         )  # TODO check workers are used
 
@@ -287,12 +271,10 @@ class PredictDataWrapper(CAREamicsPredictData):
        Prediction data.
    data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
        Data type, see `SupportedData` for available options.
-    mean : float
-        Mean value for normalization, only used if Normalization is defined in the
-        transforms.
-    std : float
-        Standard deviation value for normalization, only used if Normalization is
-        defined in the transform.
+    image_means : list of float
+        Mean values for normalization, only used if Normalization is defined.
+    image_stds : list of float
+        Std values for normalization, only used if Normalization is defined.
    tile_size : Tuple[int, ...]
        Tile size, 2D or 3D tile size.
    tile_overlap : Tuple[int, ...]
@@ -316,8 +298,8 @@ class PredictDataWrapper(CAREamicsPredictData):
         self,
         pred_data: Union[str, Path, np.ndarray],
         data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        mean: float,
-        std: float,
+        image_means=list[float],
+        image_stds=list[float],
         tile_size: Optional[Tuple[int, ...]] = None,
         tile_overlap: Optional[Tuple[int, ...]] = None,
         axes: str = "YX",
@@ -336,12 +318,10 @@ class PredictDataWrapper(CAREamicsPredictData):
            Prediction data.
        data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
            Data type, see `SupportedData` for available options.
-        mean : float
-            Mean value for normalization, only used if Normalization is defined in the
-            transforms.
-        std : float
-            Standard deviation value for normalization, only used if Normalization is
-            defined in the transform.
+        image_means : list of float
+            Mean values for normalization, only used if Normalization is defined.
+        image_stds : list of float
+            Std values for normalization, only used if Normalization is defined.
        tile_size : List[int]
            Tile size, 2D or 3D tile size.
        tile_overlap : List[int]
@@ -367,8 +347,8 @@ class PredictDataWrapper(CAREamicsPredictData):
             "tile_size": tile_size,
             "tile_overlap": tile_overlap,
             "axes": axes,
-            "mean": mean,
-            "std": std,
+            "image_means": image_means,
+            "image_stds": image_stds,
             "tta": tta_transforms,
             "batch_size": batch_size,
             "transforms": [],
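On the wrapper side, the scalar mean/std arguments become per-channel lists. A sketch of the updated call; the data, statistics, and tiling values are made up, and arguments with defaults are omitted:

import numpy as np

from careamics.lightning_prediction_datamodule import PredictDataWrapper

datamodule = PredictDataWrapper(
    pred_data=np.random.rand(256, 256).astype(np.float32),
    data_type="array",
    image_means=[0.5],  # one entry per channel
    image_stds=[0.2],
    tile_size=(64, 64),
    tile_overlap=(32, 32),
    axes="YX",
)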