careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff shows the changes between publicly available versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of careamics has been flagged as potentially problematic.
Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -1,69 +1,37 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader
-from torch.utils.data.dataloader import default_collate
 
 from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
-from careamics.config.tile_information import TileInformation
-from careamics.dataset.dataset_utils import (
-    get_read_func,
-    list_files,
-)
-from careamics.dataset.in_memory_dataset import (
-    InMemoryPredictionDataset,
-)
-from careamics.dataset.iterable_dataset import (
-    IterablePredictionDataset,
+from careamics.dataset import (
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
 )
+from careamics.dataset.dataset_utils import list_files
+from careamics.dataset.tiling.collate_tiles import collate_tiles
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger
 
-PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+PredictDatasetType = Union[
+    InMemoryPredDataset,
+    InMemoryTiledPredDataset,
+    IterablePredDataset,
+    IterableTiledPredDataset,
+]
 
 logger = get_logger(__name__)
 
 
-def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
-    """
-    Collate tiles received from CAREamics prediction dataloader.
-
-    CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
-    case of non-tiled data, this function will return the arrays. In case of tiled data,
-    it will return the arrays, the last tile flag, the overlap crop coordinates and the
-    stitch coordinates.
-
-    Parameters
-    ----------
-    batch : List[Tuple[np.ndarray, TileInformation], ...]
-        Batch of tiles.
-
-    Returns
-    -------
-    Any
-        Collated batch.
-    """
-    first_tile_info: TileInformation = batch[0][1]
-    # if not tiled, then return arrays
-    if not first_tile_info.tiled:
-        arrays, _ = zip(*batch)
-
-        return default_collate(arrays)
-    # else we explicit the last_tile flag and coordinates
-    else:
-        new_batch = [
-            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
-            for tile, t in batch
-        ]
-
-        return default_collate(new_batch)
-
-
-class CAREamicsPredictData(L.LightningDataModule):
+class PredictDataModule(L.LightningDataModule):
     """
     CAREamics Lightning prediction data module.
 
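The removed `_collate_tiles` helper moves to `careamics.dataset.tiling.collate_tiles` and, as the dataloader hunk further down shows, is now only applied when tiling is enabled. Its docstring above describes the protocol: the dataloader yields `(array, TileInformation)` tuples, and tiled batches are expanded with the metadata needed for stitching. A minimal self-contained sketch of that collation pattern, with a hypothetical `TileInfo` stand-in for careamics' actual `TileInformation` model:

    from typing import NamedTuple

    import numpy as np
    from torch.utils.data.dataloader import default_collate


    class TileInfo(NamedTuple):
        # hypothetical stand-in for careamics' TileInformation
        last_tile: bool
        array_shape: tuple
        overlap_crop_coords: tuple
        stitch_coords: tuple


    def collate_tiles_sketch(batch):
        # expose the per-tile metadata so the prediction loop can stitch tiles
        new_batch = [
            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
            for tile, t in batch
        ]
        return default_collate(new_batch)


    info = TileInfo(False, (1, 256, 256), ((0, 128), (0, 128)), ((0, 104), (0, 104)))
    batch = [(np.zeros((1, 128, 128), dtype=np.float32), info)] * 2
    tiles, last, shape, crop, stitch = collate_tiles_sketch(batch)
    print(tiles.shape)  # torch.Size([2, 1, 128, 128])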
@@ -82,9 +50,9 @@ class CAREamicsPredictData(L.LightningDataModule):
     ----------
     pred_config : InferenceModel
         Pydantic model for CAREamics prediction configuration.
-    pred_data : Union[Path, str, np.ndarray]
+    pred_data : pathlib.Path or str or numpy.ndarray
         Prediction data, can be a path to a folder, a file or a numpy array.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
         Function to read custom types, by default None.
     extension_filter : str, optional
         Filter to filter file extensions for custom types, by default "".
@@ -95,7 +63,7 @@ class CAREamicsPredictData(L.LightningDataModule):
     def __init__(
         self,
         pred_config: InferenceConfig,
-        pred_data: Union[Path, str, np.ndarray],
+        pred_data: Union[Path, str, NDArray],
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,
@@ -118,9 +86,9 @@ class CAREamicsPredictData(L.LightningDataModule):
         ----------
         pred_config : InferenceModel
             Pydantic model for CAREamics prediction configuration.
-        pred_data : Union[Path, str, np.ndarray]
+        pred_data : pathlib.Path or str or numpy.ndarray
             Prediction data, can be a path to a folder, a file or a numpy array.
-        read_source_func : Optional[Callable], optional
+        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter to filter file extensions for custom types, by default "".
@@ -182,6 +150,9 @@ class CAREamicsPredictData(L.LightningDataModule):
         self.tile_size = pred_config.tile_size
         self.tile_overlap = pred_config.tile_overlap
 
+        # check if it is tiled
+        self.tiled = self.tile_size is not None and self.tile_overlap is not None
+
         # read source function
         if pred_config.data_type == SupportedData.CUSTOM:
             # mypy check
@@ -212,17 +183,29 @@ class CAREamicsPredictData(L.LightningDataModule):
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
-            # prediction dataset
-            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
-                prediction_config=self.prediction_config,
-                inputs=self.pred_data,
-            )
+            if self.tiled:
+                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
+            else:
+                self.predict_dataset = InMemoryPredDataset(
+                    prediction_config=self.prediction_config,
+                    inputs=self.pred_data,
+                )
         else:
-            self.predict_dataset = IterablePredictionDataset(
-                prediction_config=self.prediction_config,
-                src_files=self.pred_files,
-                read_source_func=self.read_source_func,
-            )
+            if self.tiled:
+                self.predict_dataset = IterableTiledPredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
+            else:
+                self.predict_dataset = IterablePredDataset(
+                    prediction_config=self.prediction_config,
+                    src_files=self.pred_files,
+                    read_source_func=self.read_source_func,
+                )
 
     def predict_dataloader(self) -> DataLoader:
         """
@@ -236,35 +219,38 @@ class CAREamicsPredictData(L.LightningDataModule):
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
-            collate_fn=_collate_tiles,
+            collate_fn=collate_tiles if self.tiled else None,
             **self.dataloader_params,
-        )  # TODO check workers are used
-
+        )
 
 
-class PredictDataWrapper(CAREamicsPredictData):
-    """
-    Wrapper around the CAREamics inference Lightning data module.
 
-    This class is used to explicitely pass the parameters usually contained in a
+def create_predict_datamodule(
+    pred_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    axes: str,
+    image_means: list[float],
+    image_stds: list[float],
+    tile_size: Optional[tuple[int, ...]] = None,
+    tile_overlap: Optional[tuple[int, ...]] = None,
+    batch_size: int = 1,
+    tta_transforms: bool = True,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    dataloader_params: Optional[dict] = None,
+) -> PredictDataModule:
+    """Create a CAREamics prediction Lightning datamodule.
+
+    This function is used to explicitely pass the parameters usually contained in an
     `inference_model` configuration.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements
-    and are coherent.
+    and are coherent. This can be done by creating a `Configuration` object beforehand
+    and passing its parameters to the different Lightning modules.
 
     The data module can be used with Path, str or numpy arrays. To use array data, set
     `data_type` to `array` and pass a numpy array to `train_data`.
 
-    The default transformations applied to the images are defined in
-    `careamics.config.inference_model`. To use different transformations, pass a list
-    of transforms. See examples
-    for more details.
-
-    The `mean` and `std` parameters are only used if Normalization is defined either
-    in the default transformations or in the `transforms` parameter. If you pass a
-    `Normalization` transform in a list as `transforms`, then the mean and std
-    parameters will be overwritten by those passed to this method.
-
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
     `data_type` to `custom` and provide a function that returns a numpy array from a
@@ -275,117 +261,73 @@ class PredictDataWrapper(CAREamicsPredictData):
     dataloaders, except for `batch_size`, which is set by the `batch_size`
     parameter.
 
-    Note that if you are using a UNet model and tiling, the tile size must be
-    divisible in every dimension by 2**d, where d is the depth of the model. This
-    avoids artefacts arising from the broken shift invariance induced by the
-    pooling layers of the UNet. If your image has less dimensions, as it may
-    happen in the Z dimension, consider padding your image.
-
     Parameters
     ----------
-    pred_data : Union[str, Path, np.ndarray]
+    pred_data : str or pathlib.Path or numpy.ndarray
         Prediction data.
-    data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+    data_type : {"array", "tiff", "custom"}
         Data type, see `SupportedData` for available options.
-    mean : float
-        Mean value for normalization, only used if Normalization is defined in the
-        transforms.
-    std : float
-        Standard deviation value for normalization, only used if Normalization is
-        defined in the transform.
-    tile_size : Tuple[int, ...]
+    axes : str
+        Axes of the data, choosen among SCZYX.
+    image_means : list of float
+        Mean values for normalization, only used if Normalization is defined.
+    image_stds : list of float
+        Std values for normalization, only used if Normalization is defined.
+    tile_size : tuple of int, optional
        Tile size, 2D or 3D tile size.
-    tile_overlap : Tuple[int, ...]
+    tile_overlap : tuple of int, optional
        Tile overlap, 2D or 3D tile overlap.
-    axes : str
-        Axes of the data, choosen amongst SCZYX.
    batch_size : int
        Batch size.
    tta_transforms : bool, optional
        Use test time augmentation, by default True.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
        Filter for file extensions, used if `data_type` is `custom`, by default "".
    dataloader_params : dict, optional
        Pytorch dataloader parameters, by default {}.
-    """
 
-    def __init__(
-        self,
-        pred_data: Union[str, Path, np.ndarray],
-        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        mean: float,
-        std: float,
-        tile_size: Optional[Tuple[int, ...]] = None,
-        tile_overlap: Optional[Tuple[int, ...]] = None,
-        axes: str = "YX",
-        batch_size: int = 1,
-        tta_transforms: bool = True,
-        read_source_func: Optional[Callable] = None,
-        extension_filter: str = "",
-        dataloader_params: Optional[dict] = None,
-    ) -> None:
-        """
-        Constructor.
+    Returns
+    -------
+    PredictDataModule
+        CAREamics prediction datamodule.
 
-        Parameters
-        ----------
-        pred_data : Union[str, Path, np.ndarray]
-            Prediction data.
-        data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
-            Data type, see `SupportedData` for available options.
-        mean : float
-            Mean value for normalization, only used if Normalization is defined in the
-            transforms.
-        std : float
-            Standard deviation value for normalization, only used if Normalization is
-            defined in the transform.
-        tile_size : List[int]
-            Tile size, 2D or 3D tile size.
-        tile_overlap : List[int]
-            Tile overlap, 2D or 3D tile overlap.
-        axes : str
-            Axes of the data, choosen amongst SCZYX.
-        batch_size : int
-            Batch size.
-        tta_transforms : bool, optional
-            Use test time augmentation, by default True.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, used if `data_type` is `custom`, by
-            default None.
-        extension_filter : str, optional
-            Filter for file extensions, used if `data_type` is `custom`, by default "".
-        dataloader_params : dict, optional
-            Pytorch dataloader parameters, by default {}.
-        """
-        if dataloader_params is None:
-            dataloader_params = {}
-        prediction_dict: Dict[str, Any] = {
-            "data_type": data_type,
-            "tile_size": tile_size,
-            "tile_overlap": tile_overlap,
-            "axes": axes,
-            "mean": mean,
-            "std": std,
-            "tta": tta_transforms,
-            "batch_size": batch_size,
-            "transforms": [],
-        }
-
-        # validate configuration
-        self.prediction_config = InferenceConfig(**prediction_dict)
-
-        # sanity check on the dataloader parameters
-        if "batch_size" in dataloader_params:
-            # remove it
-            del dataloader_params["batch_size"]
-
-        super().__init__(
-            pred_config=self.prediction_config,
-            pred_data=pred_data,
-            read_source_func=read_source_func,
-            extension_filter=extension_filter,
-            dataloader_params=dataloader_params,
-        )
+    Notes
+    -----
+    If you are using a UNet model and tiling, the tile size must be
+    divisible in every dimension by 2**d, where d is the depth of the model. This
+    avoids artefacts arising from the broken shift invariance induced by the
+    pooling layers of the UNet. If your image has less dimensions, as it may
+    happen in the Z dimension, consider padding your image.
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    prediction_dict: dict[str, Any] = {
+        "data_type": data_type,
+        "tile_size": tile_size,
+        "tile_overlap": tile_overlap,
+        "axes": axes,
+        "image_means": image_means,
+        "image_stds": image_stds,
+        "tta_transforms": tta_transforms,
+        "batch_size": batch_size,
+    }
+
+    # validate configuration
+    prediction_config = InferenceConfig(**prediction_dict)
+
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return PredictDataModule(
+        pred_config=prediction_config,
+        pred_data=pred_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
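Taken together, the class rename and the new factory replace `PredictDataWrapper` as the inference entry point. A minimal usage sketch against the rc8 signature above, assuming `create_predict_datamodule` is re-exported from `careamics.lightning` (the new `lightning/__init__.py` in the file list suggests so), with illustrative data and statistics:

    import numpy as np
    from careamics.lightning import create_predict_datamodule

    image = np.random.rand(512, 512).astype(np.float32)

    # Per the Notes section: with a UNet of depth d, tile sizes must be
    # divisible by 2**d in every dimension (e.g. depth 3 -> multiples of 8).
    datamodule = create_predict_datamodule(
        pred_data=image,
        data_type="array",
        axes="YX",
        image_means=[float(image.mean())],
        image_stds=[float(image.std())],
        tile_size=(256, 256),
        tile_overlap=(48, 48),
        batch_size=1,
    )

    # The Lightning Trainer drives setup() and predict_dataloader(), e.g. via
    # trainer.predict(model, datamodule=datamodule); with tiling enabled and
    # array input, setup() selects InMemoryTiledPredDataset.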