careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff shows the changes between two package versions as published to their public registries. It is provided for informational purposes only.

Files changed (54)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +83 -62
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -0
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +2 -0
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +1 -79
  11. careamics/config/configuration_model.py +12 -7
  12. careamics/config/data_model.py +29 -10
  13. careamics/config/inference_model.py +12 -2
  14. careamics/config/optimizer_models.py +6 -0
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/tile_information.py +10 -0
  17. careamics/config/training_model.py +5 -1
  18. careamics/dataset/dataset_utils/__init__.py +0 -6
  19. careamics/dataset/dataset_utils/file_utils.py +1 -1
  20. careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
  21. careamics/dataset/in_memory_dataset.py +37 -21
  22. careamics/dataset/iterable_dataset.py +38 -34
  23. careamics/dataset/iterable_pred_dataset.py +2 -1
  24. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  25. careamics/dataset/patching/patching.py +53 -37
  26. careamics/file_io/__init__.py +7 -0
  27. careamics/file_io/read/__init__.py +11 -0
  28. careamics/file_io/read/get_func.py +56 -0
  29. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
  30. careamics/file_io/write/__init__.py +9 -0
  31. careamics/file_io/write/get_func.py +59 -0
  32. careamics/file_io/write/tiff.py +39 -0
  33. careamics/lightning/__init__.py +17 -0
  34. careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
  35. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
  36. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
  37. careamics/model_io/bmz_io.py +1 -1
  38. careamics/model_io/model_io_utils.py +1 -1
  39. careamics/prediction_utils/__init__.py +0 -2
  40. careamics/prediction_utils/prediction_outputs.py +18 -46
  41. careamics/prediction_utils/stitch_prediction.py +17 -14
  42. careamics/utils/__init__.py +2 -0
  43. careamics/utils/autocorrelation.py +40 -0
  44. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
  45. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
  46. careamics/config/configuration_example.py +0 -86
  47. careamics/dataset/dataset_utils/read_utils.py +0 -27
  48. careamics/prediction_utils/create_pred_datamodule.py +0 -185
  49. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  50. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  51. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  52. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  53. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
  54. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
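
The headline change is a reorganization of the package: the Lightning classes move into a new `careamics.lightning` subpackage (entries 33-36) and the read/write helpers move from `careamics.dataset.dataset_utils` into a new `careamics.file_io` package (entries 26-32). A minimal sketch of how imports migrate, assuming the new `__init__.py` files re-export the moved symbols:

    # 0.1.0rc7 (old layout)
    # from careamics.lightning_prediction_datamodule import CAREamicsPredictData
    # from careamics.dataset.dataset_utils import get_read_func

    # 0.1.0rc8 (new layout)
    from careamics.lightning import PredictDataModule, create_predict_datamodule
    from careamics.file_io.read import get_read_func

The representative diff below, for entry 35, shows both moves at work.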
careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py}
@@ -1,10 +1,11 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader
 
 from careamics.config import InferenceConfig
@@ -15,11 +16,9 @@ from careamics.dataset import (
     IterablePredDataset,
     IterableTiledPredDataset,
 )
-from careamics.dataset.dataset_utils import (
-    get_read_func,
-    list_files,
-)
+from careamics.dataset.dataset_utils import list_files
 from careamics.dataset.tiling.collate_tiles import collate_tiles
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger
 
 PredictDatasetType = Union[
@@ -32,7 +31,7 @@ PredictDatasetType = Union[
 logger = get_logger(__name__)
 
 
-class CAREamicsPredictData(L.LightningDataModule):
+class PredictDataModule(L.LightningDataModule):
     """
     CAREamics Lightning prediction data module.
 
@@ -51,9 +50,9 @@ class CAREamicsPredictData(L.LightningDataModule):
     ----------
     pred_config : InferenceModel
         Pydantic model for CAREamics prediction configuration.
-    pred_data : Union[Path, str, np.ndarray]
+    pred_data : pathlib.Path or str or numpy.ndarray
         Prediction data, can be a path to a folder, a file or a numpy array.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
         Function to read custom types, by default None.
     extension_filter : str, optional
         Filter to filter file extensions for custom types, by default "".
@@ -64,7 +63,7 @@ class CAREamicsPredictData(L.LightningDataModule):
     def __init__(
         self,
         pred_config: InferenceConfig,
-        pred_data: Union[Path, str, np.ndarray],
+        pred_data: Union[Path, str, NDArray],
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,
@@ -87,9 +86,9 @@ class CAREamicsPredictData(L.LightningDataModule):
         ----------
         pred_config : InferenceModel
             Pydantic model for CAREamics prediction configuration.
-        pred_data : Union[Path, str, np.ndarray]
+        pred_data : pathlib.Path or str or numpy.ndarray
             Prediction data, can be a path to a folder, a file or a numpy array.
-        read_source_func : Optional[Callable], optional
+        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter to filter file extensions for custom types, by default "".
@@ -222,33 +221,36 @@ class CAREamicsPredictData(L.LightningDataModule):
             batch_size=self.batch_size,
             collate_fn=collate_tiles if self.tiled else None,
             **self.dataloader_params,
-        )  # TODO check workers are used
+        )
 
 
-class PredictDataWrapper(CAREamicsPredictData):
-    """
-    Wrapper around the CAREamics inference Lightning data module.
-
-    This class is used to explicitely pass the parameters usually contained in a
+def create_predict_datamodule(
+    pred_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    axes: str,
+    image_means: list[float],
+    image_stds: list[float],
+    tile_size: Optional[tuple[int, ...]] = None,
+    tile_overlap: Optional[tuple[int, ...]] = None,
+    batch_size: int = 1,
+    tta_transforms: bool = True,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    dataloader_params: Optional[dict] = None,
+) -> PredictDataModule:
+    """Create a CAREamics prediction Lightning datamodule.
+
+    This function is used to explicitely pass the parameters usually contained in an
     `inference_model` configuration.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements
-    and are coherent.
+    and are coherent. This can be done by creating a `Configuration` object beforehand
+    and passing its parameters to the different Lightning modules.
 
     The data module can be used with Path, str or numpy arrays. To use array data, set
     `data_type` to `array` and pass a numpy array to `train_data`.
 
-    The default transformations applied to the images are defined in
-    `careamics.config.inference_model`. To use different transformations, pass a list
-    of transforms. See examples
-    for more details.
-
-    The `mean` and `std` parameters are only used if Normalization is defined either
-    in the default transformations or in the `transforms` parameter. If you pass a
-    `Normalization` transform in a list as `transforms`, then the mean and std
-    parameters will be overwritten by those passed to this method.
-
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
     `data_type` to `custom` and provide a function that returns a numpy array from a
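
The paragraph above is cut off at the hunk boundary and continues below. For the `custom` data type it describes, `read_source_func` must be a callable returning a numpy array from a path. A minimal sketch of such a reader (the `.npy` loader and the glob format of `extension_filter` are illustrative assumptions, not CAREamics API):

    from pathlib import Path

    import numpy as np
    from numpy.typing import NDArray

    def read_npy(path: Path, *args, **kwargs) -> NDArray:
        # hypothetical reader: any callable mapping a file path
        # to a numpy array can serve as read_source_func
        return np.load(path)

This would then be passed as `read_source_func=read_npy` together with a matching filter such as `extension_filter="*.npy"`.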
@@ -259,113 +261,73 @@ class PredictDataWrapper(CAREamicsPredictData):
     dataloaders, except for `batch_size`, which is set by the `batch_size`
     parameter.
 
-    Note that if you are using a UNet model and tiling, the tile size must be
-    divisible in every dimension by 2**d, where d is the depth of the model. This
-    avoids artefacts arising from the broken shift invariance induced by the
-    pooling layers of the UNet. If your image has less dimensions, as it may
-    happen in the Z dimension, consider padding your image.
-
     Parameters
     ----------
-    pred_data : Union[str, Path, np.ndarray]
+    pred_data : str or pathlib.Path or numpy.ndarray
         Prediction data.
-    data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+    data_type : {"array", "tiff", "custom"}
         Data type, see `SupportedData` for available options.
+    axes : str
+        Axes of the data, choosen among SCZYX.
     image_means : list of float
         Mean values for normalization, only used if Normalization is defined.
     image_stds : list of float
         Std values for normalization, only used if Normalization is defined.
-    tile_size : Tuple[int, ...]
+    tile_size : tuple of int, optional
        Tile size, 2D or 3D tile size.
-    tile_overlap : Tuple[int, ...]
+    tile_overlap : tuple of int, optional
        Tile overlap, 2D or 3D tile overlap.
-    axes : str
-        Axes of the data, choosen amongst SCZYX.
    batch_size : int
        Batch size.
    tta_transforms : bool, optional
        Use test time augmentation, by default True.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
        Filter for file extensions, used if `data_type` is `custom`, by default "".
    dataloader_params : dict, optional
        Pytorch dataloader parameters, by default {}.
-    """
 
-    def __init__(
-        self,
-        pred_data: Union[str, Path, np.ndarray],
-        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        image_means=list[float],
-        image_stds=list[float],
-        tile_size: Optional[Tuple[int, ...]] = None,
-        tile_overlap: Optional[Tuple[int, ...]] = None,
-        axes: str = "YX",
-        batch_size: int = 1,
-        tta_transforms: bool = True,
-        read_source_func: Optional[Callable] = None,
-        extension_filter: str = "",
-        dataloader_params: Optional[dict] = None,
-    ) -> None:
-        """
-        Constructor.
+    Returns
+    -------
+    PredictDataModule
+        CAREamics prediction datamodule.
 
-        Parameters
-        ----------
-        pred_data : Union[str, Path, np.ndarray]
-            Prediction data.
-        data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
-            Data type, see `SupportedData` for available options.
-        image_means : list of float
-            Mean values for normalization, only used if Normalization is defined.
-        image_stds : list of float
-            Std values for normalization, only used if Normalization is defined.
-        tile_size : List[int]
-            Tile size, 2D or 3D tile size.
-        tile_overlap : List[int]
-            Tile overlap, 2D or 3D tile overlap.
-        axes : str
-            Axes of the data, choosen amongst SCZYX.
-        batch_size : int
-            Batch size.
-        tta_transforms : bool, optional
-            Use test time augmentation, by default True.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, used if `data_type` is `custom`, by
-            default None.
-        extension_filter : str, optional
-            Filter for file extensions, used if `data_type` is `custom`, by default "".
-        dataloader_params : dict, optional
-            Pytorch dataloader parameters, by default {}.
-        """
-        if dataloader_params is None:
-            dataloader_params = {}
-        prediction_dict: Dict[str, Any] = {
-            "data_type": data_type,
-            "tile_size": tile_size,
-            "tile_overlap": tile_overlap,
-            "axes": axes,
-            "image_means": image_means,
-            "image_stds": image_stds,
-            "tta": tta_transforms,
-            "batch_size": batch_size,
-            "transforms": [],
-        }
-
-        # validate configuration
-        self.prediction_config = InferenceConfig(**prediction_dict)
-
-        # sanity check on the dataloader parameters
-        if "batch_size" in dataloader_params:
-            # remove it
-            del dataloader_params["batch_size"]
-
-        super().__init__(
-            pred_config=self.prediction_config,
-            pred_data=pred_data,
-            read_source_func=read_source_func,
-            extension_filter=extension_filter,
-            dataloader_params=dataloader_params,
-        )
+    Notes
+    -----
+    If you are using a UNet model and tiling, the tile size must be
+    divisible in every dimension by 2**d, where d is the depth of the model. This
+    avoids artefacts arising from the broken shift invariance induced by the
+    pooling layers of the UNet. If your image has less dimensions, as it may
+    happen in the Z dimension, consider padding your image.
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    prediction_dict: dict[str, Any] = {
+        "data_type": data_type,
+        "tile_size": tile_size,
+        "tile_overlap": tile_overlap,
+        "axes": axes,
+        "image_means": image_means,
+        "image_stds": image_stds,
+        "tta_transforms": tta_transforms,
+        "batch_size": batch_size,
+    }
+
+    # validate configuration
+    prediction_config = InferenceConfig(**prediction_dict)
+
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return PredictDataModule(
+        pred_config=prediction_config,
+        pred_data=pred_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
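
With `PredictDataWrapper` removed, constructing the datamodule by hand becomes a factory call. A sketch of the new entry point based on the signature above (the input array, statistics, and tile values are illustrative; the import path follows the new file location):

    import numpy as np

    from careamics.lightning.predict_data_module import create_predict_datamodule

    image = np.random.rand(256, 256).astype(np.float32)

    datamodule = create_predict_datamodule(
        pred_data=image,
        data_type="array",
        axes="YX",  # no longer defaults to "YX"; must be passed explicitly
        image_means=[float(image.mean())],
        image_stds=[float(image.std())],
        tile_size=(128, 128),  # for UNets: divisible by 2**depth in every dimension
        tile_overlap=(48, 48),
        batch_size=1,
    )

Note that the keys handed to `InferenceConfig` changed as well: `tta` became `tta_transforms` and the empty `transforms` list is gone.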