careamics-0.1.0rc6-py3-none-any.whl → careamics-0.1.0rc7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -4,25 +4,24 @@ from __future__ import annotations
4
4
 
5
5
  import copy
6
6
  from pathlib import Path
7
- from typing import Any, Callable, List, Optional, Tuple, Union
7
+ from typing import Any, Callable, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
12
  from careamics.transforms import Compose
13
13
 
14
- from ..config import DataConfig, InferenceConfig
15
- from ..config.tile_information import TileInformation
14
+ from ..config import DataConfig
16
15
  from ..config.transformations import NormalizeModel
17
16
  from ..utils.logging import get_logger
18
- from .dataset_utils import read_tiff, reshape_array
17
+ from .dataset_utils import read_tiff
19
18
  from .patching.patching import (
19
+ PatchedOutput,
20
20
  prepare_patches_supervised,
21
21
  prepare_patches_supervised_array,
22
22
  prepare_patches_unsupervised,
23
23
  prepare_patches_unsupervised_array,
24
24
  )
25
- from .patching.tiled_patching import extract_tiles
26
25
 
27
26
  logger = get_logger(__name__)
28
27
 
@@ -32,11 +31,12 @@ class InMemoryDataset(Dataset):
32
31
 
33
32
  Parameters
34
33
  ----------
35
- data_config : DataConfig
34
+ data_config : CAREamics DataConfig
35
+ (see careamics.config.data_model.DataConfig)
36
36
  Data configuration.
37
- inputs : Union[np.ndarray, List[Path]]
37
+ inputs : numpy.ndarray or list[pathlib.Path]
38
38
  Input data.
39
- input_target : Optional[Union[np.ndarray, List[Path]]], optional
39
+ input_target : numpy.ndarray or list[pathlib.Path], optional
40
40
  Target data, by default None.
41
41
  read_source_func : Callable, optional
42
42
  Read source function for custom types, by default read_tiff.
@@ -47,8 +47,8 @@ class InMemoryDataset(Dataset):
47
47
  def __init__(
48
48
  self,
49
49
  data_config: DataConfig,
50
- inputs: Union[np.ndarray, List[Path]],
51
- input_target: Optional[Union[np.ndarray, List[Path]]] = None,
50
+ inputs: Union[np.ndarray, list[Path]],
51
+ input_target: Optional[Union[np.ndarray, list[Path]]] = None,
52
52
  read_source_func: Callable = read_tiff,
53
53
  **kwargs: Any,
54
54
  ) -> None:
@@ -59,9 +59,9 @@ class InMemoryDataset(Dataset):
59
59
  ----------
60
60
  data_config : DataConfig
61
61
  Data configuration.
62
- inputs : Union[np.ndarray, List[Path]]
62
+ inputs : numpy.ndarray or list[pathlib.Path]
63
63
  Input data.
64
- input_target : Optional[Union[np.ndarray, List[Path]]], optional
64
+ input_target : numpy.ndarray or list[pathlib.Path], optional
65
65
  Target data, by default None.
66
66
  read_source_func : Callable, optional
67
67
  Read source function for custom types, by default read_tiff.
@@ -79,29 +79,51 @@ class InMemoryDataset(Dataset):
79
79
 
80
80
  # Generate patches
81
81
  supervised = self.input_targets is not None
82
- patch_data = self._prepare_patches(supervised)
82
+ patches_data = self._prepare_patches(supervised)
83
83
 
84
- # Add results to members
85
- self.patches, self.patch_targets, computed_mean, computed_std = patch_data
84
+ # Unpack the dataclass
85
+ self.data = patches_data.patches
86
+ self.data_targets = patches_data.targets
86
87
 
87
- if not self.data_config.mean or not self.data_config.std:
88
- self.mean, self.std = computed_mean, computed_std
89
- logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
90
-
91
- # update mean and std in configuration
92
- # the object is mutable and should then be recorded in the CAREamist obj
93
- self.data_config.set_mean_and_std(self.mean, self.std)
88
+ if self.data_config.image_means is None:
89
+ self.image_means = patches_data.image_stats.means
90
+ self.image_stds = patches_data.image_stats.stds
91
+ logger.info(
92
+ f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
93
+ )
94
94
  else:
95
- self.mean, self.std = self.data_config.mean, self.data_config.std
95
+ self.image_means = self.data_config.image_means
96
+ self.image_stds = self.data_config.image_stds
96
97
 
98
+ if self.data_config.target_means is None:
99
+ self.target_means = patches_data.target_stats.means
100
+ self.target_stds = patches_data.target_stats.stds
101
+ else:
102
+ self.target_means = self.data_config.target_means
103
+ self.target_stds = self.data_config.target_stds
104
+
105
+ # update mean and std in configuration
106
+ # the object is mutable and should then be recorded in the CAREamist obj
107
+ self.data_config.set_mean_and_std(
108
+ image_means=self.image_means,
109
+ image_stds=self.image_stds,
110
+ target_means=self.target_means,
111
+ target_stds=self.target_stds,
112
+ )
97
113
  # get transforms
98
114
  self.patch_transform = Compose(
99
- transform_list=self.data_config.transforms,
115
+ transform_list=[
116
+ NormalizeModel(
117
+ image_means=self.image_means,
118
+ image_stds=self.image_stds,
119
+ target_means=self.target_means,
120
+ target_stds=self.target_stds,
121
+ )
122
+ ]
123
+ + self.data_config.transforms,
100
124
  )
101
125
 
102
- def _prepare_patches(
103
- self, supervised: bool
104
- ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
126
+ def _prepare_patches(self, supervised: bool) -> PatchedOutput:
105
127
  """
106
128
  Iterate over data source and create an array of patches.
107
129
 
@@ -112,7 +134,7 @@ class InMemoryDataset(Dataset):
112
134
 
113
135
  Returns
114
136
  -------
115
- np.ndarray
137
+ numpy.ndarray
116
138
  Array of patches.
117
139
  """
118
140
  if supervised:
@@ -163,9 +185,9 @@ class InMemoryDataset(Dataset):
163
185
  int
164
186
  Length of the dataset.
165
187
  """
166
- return len(self.patches)
188
+ return self.data.shape[0]
167
189
 
168
- def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
190
+ def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
169
191
  """
170
192
  Return the patch corresponding to the provided index.
171
193
 
@@ -176,7 +198,7 @@ class InMemoryDataset(Dataset):
176
198
 
177
199
  Returns
178
200
  -------
179
- Tuple[np.ndarray]
201
+ tuple of numpy.ndarray
180
202
  Patch.
181
203
 
182
204
  Raises
@@ -184,16 +206,16 @@ class InMemoryDataset(Dataset):
184
206
  ValueError
185
207
  If dataset mean and std are not set.
186
208
  """
187
- patch = self.patches[index]
209
+ patch = self.data[index]
188
210
 
189
211
  # if there is a target
190
- if self.patch_targets is not None:
212
+ if self.data_targets is not None:
191
213
  # get target
192
- target = self.patch_targets[index]
214
+ target = self.data_targets[index]
193
215
 
194
216
  return self.patch_transform(patch=patch, target=target)
195
217
 
196
- elif self.data_config.has_n2v_manipulate():
218
+ elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
197
219
  return self.patch_transform(patch=patch)
198
220
  else:
199
221
  raise ValueError(
@@ -219,7 +241,7 @@ class InMemoryDataset(Dataset):
219
241
 
220
242
  Returns
221
243
  -------
222
- InMemoryDataset
244
+ CAREamics InMemoryDataset
223
245
  New dataset with the extracted patches.
224
246
 
225
247
  Raises
@@ -249,151 +271,24 @@ class InMemoryDataset(Dataset):
249
271
  indices = np.random.choice(total_patches, n_patches, replace=False)
250
272
 
251
273
  # extract patches
252
- val_patches = self.patches[indices]
274
+ val_patches = self.data[indices]
253
275
 
254
276
  # remove patches from self.patch
255
- self.patches = np.delete(self.patches, indices, axis=0)
277
+ self.data = np.delete(self.data, indices, axis=0)
256
278
 
257
279
  # same for targets
258
- if self.patch_targets is not None:
259
- val_targets = self.patch_targets[indices]
260
- self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
280
+ if self.data_targets is not None:
281
+ val_targets = self.data_targets[indices]
282
+ self.data_targets = np.delete(self.data_targets, indices, axis=0)
261
283
 
262
284
  # clone the dataset
263
285
  dataset = copy.deepcopy(self)
264
286
 
265
287
  # reassign patches
266
- dataset.patches = val_patches
288
+ dataset.data = val_patches
267
289
 
268
290
  # reassign targets
269
- if self.patch_targets is not None:
270
- dataset.patch_targets = val_targets
291
+ if self.data_targets is not None:
292
+ dataset.data_targets = val_targets
271
293
 
272
294
  return dataset
273
-
274
-
275
- class InMemoryPredictionDataset(Dataset):
276
- """
277
- Dataset storing data in memory and allowing generating patches from it.
278
-
279
- Parameters
280
- ----------
281
- prediction_config : InferenceConfig
282
- Prediction configuration.
283
- inputs : np.ndarray
284
- Input data.
285
- data_target : Optional[np.ndarray], optional
286
- Target data, by default None.
287
- read_source_func : Optional[Callable], optional
288
- Read source function for custom types, by default read_tiff.
289
- """
290
-
291
- def __init__(
292
- self,
293
- prediction_config: InferenceConfig,
294
- inputs: np.ndarray,
295
- data_target: Optional[np.ndarray] = None,
296
- read_source_func: Optional[Callable] = read_tiff,
297
- ) -> None:
298
- """Constructor.
299
-
300
- Parameters
301
- ----------
302
- prediction_config : InferenceConfig
303
- Prediction configuration.
304
- inputs : np.ndarray
305
- Input data.
306
- data_target : Optional[np.ndarray], optional
307
- Target data, by default None.
308
- read_source_func : Optional[Callable], optional
309
- Read source function for custom types, by default read_tiff.
310
-
311
- Raises
312
- ------
313
- ValueError
314
- If data_path is not a directory.
315
- """
316
- self.pred_config = prediction_config
317
- self.input_array = inputs
318
- self.axes = self.pred_config.axes
319
- self.tile_size = self.pred_config.tile_size
320
- self.tile_overlap = self.pred_config.tile_overlap
321
- self.mean = self.pred_config.mean
322
- self.std = self.pred_config.std
323
- self.data_target = data_target
324
-
325
- # tiling only if both tile size and overlap are provided
326
- self.tiling = self.tile_size is not None and self.tile_overlap is not None
327
-
328
- # read function
329
- self.read_source_func = read_source_func
330
-
331
- # Generate patches
332
- self.data = self._prepare_tiles()
333
- self.mean, self.std = self.pred_config.mean, self.pred_config.std
334
-
335
- # get transforms
336
- self.patch_transform = Compose(
337
- transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
338
- )
339
-
340
- def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
341
- """
342
- Iterate over data source and create an array of patches.
343
-
344
- Returns
345
- -------
346
- List[XArrayTile]
347
- List of tiles.
348
- """
349
- # reshape array
350
- reshaped_sample = reshape_array(self.input_array, self.axes)
351
-
352
- if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
353
- # generate patches, which returns a generator
354
- patch_generator = extract_tiles(
355
- arr=reshaped_sample,
356
- tile_size=self.tile_size,
357
- overlaps=self.tile_overlap,
358
- )
359
- patches_list = list(patch_generator)
360
-
361
- if len(patches_list) == 0:
362
- raise ValueError("No tiles generated, ")
363
-
364
- return patches_list
365
- else:
366
- array_shape = reshaped_sample.squeeze().shape
367
- return [(reshaped_sample, TileInformation(array_shape=array_shape))]
368
-
369
- def __len__(self) -> int:
370
- """
371
- Return the length of the dataset.
372
-
373
- Returns
374
- -------
375
- int
376
- Length of the dataset.
377
- """
378
- return len(self.data)
379
-
380
- def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
381
- """
382
- Return the patch corresponding to the provided index.
383
-
384
- Parameters
385
- ----------
386
- index : int
387
- Index of the patch to return.
388
-
389
- Returns
390
- -------
391
- Tuple[np.ndarray, TileInformation]
392
- Transformed patch.
393
- """
394
- tile_array, tile_info = self.data[index]
395
-
396
- # Apply transforms
397
- transformed_tile, _ = self.patch_transform(patch=tile_array)
398
-
399
- return transformed_tile, tile_info
@@ -0,0 +1,88 @@
1
+ """In-memory prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.transformations import NormalizeModel
12
+ from .dataset_utils import reshape_array
13
+
14
+
15
+ class InMemoryPredDataset(Dataset):
16
+ """Simple prediction dataset returning images along the sample axis.
17
+
18
+ Parameters
19
+ ----------
20
+ prediction_config : InferenceConfig
21
+ Prediction configuration.
22
+ inputs : NDArray
23
+ Input data.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ prediction_config: InferenceConfig,
29
+ inputs: NDArray,
30
+ ) -> None:
31
+ """Constructor.
32
+
33
+ Parameters
34
+ ----------
35
+ prediction_config : InferenceConfig
36
+ Prediction configuration.
37
+ inputs : NDArray
38
+ Input data.
39
+
40
+ Raises
41
+ ------
42
+ ValueError
43
+ If data_path is not a directory.
44
+ """
45
+ self.pred_config = prediction_config
46
+ self.input_array = inputs
47
+ self.axes = self.pred_config.axes
48
+ self.image_means = self.pred_config.image_means
49
+ self.image_stds = self.pred_config.image_stds
50
+
51
+ # Reshape data
52
+ self.data = reshape_array(self.input_array, self.axes)
53
+
54
+ # get transforms
55
+ self.patch_transform = Compose(
56
+ transform_list=[
57
+ NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
58
+ ],
59
+ )
60
+
61
+ def __len__(self) -> int:
62
+ """
63
+ Return the length of the dataset.
64
+
65
+ Returns
66
+ -------
67
+ int
68
+ Length of the dataset.
69
+ """
70
+ return len(self.data)
71
+
72
+ def __getitem__(self, index: int) -> NDArray:
73
+ """
74
+ Return the patch corresponding to the provided index.
75
+
76
+ Parameters
77
+ ----------
78
+ index : int
79
+ Index of the patch to return.
80
+
81
+ Returns
82
+ -------
83
+ NDArray
84
+ Transformed patch.
85
+ """
86
+ transformed_patch, _ = self.patch_transform(patch=self.data[index])
87
+
88
+ return transformed_patch
@@ -0,0 +1,129 @@
1
+ """In-memory tiled prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.tile_information import TileInformation
12
+ from ..config.transformations import NormalizeModel
13
+ from .dataset_utils import reshape_array
14
+ from .tiling import extract_tiles
15
+
16
+
17
+ class InMemoryTiledPredDataset(Dataset):
18
+ """Prediction dataset storing data in memory and returning tiles of each image.
19
+
20
+ Parameters
21
+ ----------
22
+ prediction_config : InferenceConfig
23
+ Prediction configuration.
24
+ inputs : NDArray
25
+ Input data.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ prediction_config: InferenceConfig,
31
+ inputs: NDArray,
32
+ ) -> None:
33
+ """Constructor.
34
+
35
+ Parameters
36
+ ----------
37
+ prediction_config : InferenceConfig
38
+ Prediction configuration.
39
+ inputs : NDArray
40
+ Input data.
41
+
42
+ Raises
43
+ ------
44
+ ValueError
45
+ If data_path is not a directory.
46
+ """
47
+ if (
48
+ prediction_config.tile_size is None
49
+ or prediction_config.tile_overlap is None
50
+ ):
51
+ raise ValueError(
52
+ "Tile size and overlap must be provided to use the tiled prediction "
53
+ "dataset."
54
+ )
55
+
56
+ self.pred_config = prediction_config
57
+ self.input_array = inputs
58
+ self.axes = self.pred_config.axes
59
+ self.tile_size = prediction_config.tile_size
60
+ self.tile_overlap = prediction_config.tile_overlap
61
+ self.image_means = self.pred_config.image_means
62
+ self.image_stds = self.pred_config.image_stds
63
+
64
+ # Generate patches
65
+ self.data = self._prepare_tiles()
66
+
67
+ # get transforms
68
+ self.patch_transform = Compose(
69
+ transform_list=[
70
+ NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
71
+ ],
72
+ )
73
+
74
+ def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
75
+ """
76
+ Iterate over data source and create an array of patches.
77
+
78
+ Returns
79
+ -------
80
+ list of tuples of NDArray and TileInformation
81
+ List of tiles and tile information.
82
+ """
83
+ # reshape array
84
+ reshaped_sample = reshape_array(self.input_array, self.axes)
85
+
86
+ # generate patches, which returns a generator
87
+ patch_generator = extract_tiles(
88
+ arr=reshaped_sample,
89
+ tile_size=self.tile_size,
90
+ overlaps=self.tile_overlap,
91
+ )
92
+ patches_list = list(patch_generator)
93
+
94
+ if len(patches_list) == 0:
95
+ raise ValueError("No tiles generated, ")
96
+
97
+ return patches_list
98
+
99
+ def __len__(self) -> int:
100
+ """
101
+ Return the length of the dataset.
102
+
103
+ Returns
104
+ -------
105
+ int
106
+ Length of the dataset.
107
+ """
108
+ return len(self.data)
109
+
110
+ def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
111
+ """
112
+ Return the patch corresponding to the provided index.
113
+
114
+ Parameters
115
+ ----------
116
+ index : int
117
+ Index of the patch to return.
118
+
119
+ Returns
120
+ -------
121
+ tuple of NDArray and TileInformation
122
+ Transformed patch.
123
+ """
124
+ tile_array, tile_info = self.data[index]
125
+
126
+ # Apply transforms
127
+ transformed_tile, _ = self.patch_transform(patch=tile_array)
128
+
129
+ return transformed_tile, tile_info