careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/dataset/zarr_dataset.py
@@ -0,0 +1,151 @@
+ """Zarr dataset."""
+
+ # from itertools import islice
+ # from typing import Callable, Dict, List, Optional, Tuple, Union
+
+ # import numpy as np
+ # import torch
+ # import zarr
+
+ # from careamics.utils import RunningStats
+ # from careamics.utils.logging import get_logger
+
+ # from ..utils import normalize
+ # from .dataset_utils import read_zarr
+ # from .patching.patching import (
+ #     generate_patches_unsupervised,
+ # )
+
+ # logger = get_logger(__name__)
+
+
+ # class ZarrDataset(torch.utils.data.IterableDataset):
+ #     """Dataset to extract patches from a zarr storage.
+
+ #     Parameters
+ #     ----------
+ #     data_source : Union[zarr.Group, zarr.Array]
+ #         Zarr storage.
+ #     axes : str
+ #         Description of axes in format STCZYX.
+ #     patch_extraction_method : Union[ExtractionStrategies, None]
+ #         Patch extraction strategy, as defined in extraction_strategy.
+ #     patch_size : Optional[Union[List[int], Tuple[int]]], optional
+ #         Size of the patches in each dimension, by default None.
+ #     num_patches : Optional[int], optional
+ #         Number of patches to extract, by default None.
+ #     mean : Optional[float], optional
+ #         Expected mean of the dataset, by default None.
+ #     std : Optional[float], optional
+ #         Expected standard deviation of the dataset, by default None.
+ #     patch_transform : Optional[Callable], optional
+ #         Patch transform callable, by default None.
+ #     patch_transform_params : Optional[Dict], optional
+ #         Patch transform parameters, by default None.
+ #     running_stats_window_perc : float, optional
+ #         Percentage of the dataset to use for calculating the initial mean and standard
+ #         deviation, by default 0.01.
+ #     mode : str, optional
+ #         train/predict, controls running stats calculation.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         data_source: Union[zarr.Group, zarr.Array],
+ #         axes: str,
+ #         patch_extraction_method: Union[SupportedExtractionStrategy, None],
+ #         patch_size: Optional[Union[List[int], Tuple[int]]] = None,
+ #         num_patches: Optional[int] = None,
+ #         mean: Optional[float] = None,
+ #         std: Optional[float] = None,
+ #         patch_transform: Optional[Callable] = None,
+ #         patch_transform_params: Optional[Dict] = None,
+ #         running_stats_window_perc: float = 0.01,
+ #         mode: str = "train",
+ #     ) -> None:
+ #         self.data_source = data_source
+ #         self.axes = axes
+ #         self.patch_extraction_method = patch_extraction_method
+ #         self.patch_size = patch_size
+ #         self.num_patches = num_patches
+ #         self.mean = mean
+ #         self.std = std
+ #         self.patch_transform = patch_transform
+ #         self.patch_transform_params = patch_transform_params
+ #         self.sample = read_zarr(self.data_source, self.axes)
+ #         self.running_stats_window = int(
+ #             np.prod(self.sample._cdata_shape) * running_stats_window_perc
+ #         )
+ #         self.mode = mode
+ #         self.running_stats = RunningStats()
+
+ #         self._calculate_initial_mean_std()
+
+ #     def _calculate_initial_mean_std(self):
+ #         """Calculate initial mean and std of the dataset."""
+ #         if self.mean is None and self.std is None:
+ #             idxs = np.random.randint(
+ #                 0,
+ #                 np.prod(self.sample._cdata_shape),
+ #                 size=max(1, self.running_stats_window),
+ #             )
+ #             random_chunks = self.sample[idxs]
+ #             self.running_stats.init(random_chunks.mean(), random_chunks.std())
+
+ #     def _generate_patches(self):
+ #         """Generate patches from the dataset and calculate running stats.
+
+ #         Yields
+ #         ------
+ #         np.ndarray
+ #             Patch.
+ #         """
+ #         patches = generate_patches_unsupervised(
+ #             self.sample,
+ #             self.patch_extraction_method,
+ #             self.patch_size,
+ #         )
+
+ #         # num_patches = np.ceil(
+ #         #     np.prod(self.sample.chunks)
+ #         #     / (np.prod(self.patch_size) * self.running_stats_window)
+ #         # ).astype(int)
+
+ #         for idx, patch in enumerate(patches):
+ #             if self.mode != "predict":
+ #                 self.running_stats.update(patch.mean())
+ #             if isinstance(patch, tuple):
+ #                 normalized_patch = normalize(
+ #                     img=patch[0],
+ #                     mean=self.running_stats.avg_mean.value,
+ #                     std=self.running_stats.avg_std.value,
+ #                 )
+ #                 patch = (normalized_patch, *patch[1:])
+ #             else:
+ #                 patch = normalize(
+ #                     img=patch,
+ #                     mean=self.running_stats.avg_mean.value,
+ #                     std=self.running_stats.avg_std.value,
+ #                 )
+
+ #             if self.patch_transform is not None:
+ #                 assert self.patch_transform_params is not None
+ #                 patch = self.patch_transform(patch, **self.patch_transform_params)
+ #             if self.num_patches is not None and idx >= self.num_patches:
+ #                 return
+ #             else:
+ #                 yield patch
+ #         self.mean = self.running_stats.avg_mean.value
+ #         self.std = self.running_stats.avg_std.value
+
+ #     def __iter__(self):
+ #         """
+ #         Iterate over data source and yield single patch.
+
+ #         Yields
+ #         ------
+ #         np.ndarray
+ #         """
+ #         worker_info = torch.utils.data.get_worker_info()
+ #         num_workers = worker_info.num_workers if worker_info is not None else 1
+ #         yield from islice(self._generate_patches(), 0, None, num_workers)
careamics/file_io/__init__.py
@@ -0,0 +1,15 @@
+ """Functions relating to reading and writing image files."""
+
+ __all__ = [
+     "read",
+     "write",
+     "get_read_func",
+     "get_write_func",
+     "ReadFunc",
+     "WriteFunc",
+     "SupportedWriteType",
+ ]
+
+ from . import read, write
+ from .read import ReadFunc, get_read_func
+ from .write import SupportedWriteType, WriteFunc, get_write_func
careamics/file_io/read/__init__.py
@@ -0,0 +1,12 @@
+ """Functions relating to reading image files of different formats."""
+
+ __all__ = [
+     "get_read_func",
+     "read_tiff",
+     "read_zarr",
+     "ReadFunc",
+ ]
+
+ from .get_func import ReadFunc, get_read_func
+ from .tiff import read_tiff
+ from .zarr import read_zarr
careamics/file_io/read/get_func.py
@@ -0,0 +1,56 @@
+ """Module to get read functions."""
+
+ from pathlib import Path
+ from typing import Callable, Dict, Protocol, Union
+
+ from numpy.typing import NDArray
+
+ from careamics.config.support import SupportedData
+
+ from .tiff import read_tiff
+
+
+ # This is very strict, function signature has to match including arg names
+ # See WriteFunc notes
+ class ReadFunc(Protocol):
+     """Protocol for type hinting read functions."""
+
+     def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
+         """
+         Type hinted callables must match this function signature (not including self).
+
+         Parameters
+         ----------
+         file_path : pathlib.Path
+             Path to file.
+         *args
+             Other positional arguments.
+         **kwargs
+             Other keyword arguments.
+         """
+
+
+ READ_FUNCS: Dict[SupportedData, ReadFunc] = {
+     SupportedData.TIFF: read_tiff,
+ }
+
+
+ def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
+     """
+     Get the read function for the data type.
+
+     Parameters
+     ----------
+     data_type : SupportedData
+         Data type.
+
+     Returns
+     -------
+     callable
+         Read function.
+     """
+     if data_type in READ_FUNCS:
+         data_type = SupportedData(data_type)  # mypy complaining about dict key type
+         return READ_FUNCS[data_type]
+     else:
+         raise NotImplementedError(f"Data type '{data_type}' is not supported.")
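
For context, a minimal usage sketch of `get_read_func`; the file path below is hypothetical, and at this release only TIFF is registered in `READ_FUNCS`:

    from pathlib import Path

    from careamics.config.support import SupportedData
    from careamics.file_io.read import get_read_func

    # resolve the reader registered for TIFF data
    read_func = get_read_func(SupportedData.TIFF)

    # hypothetical file path, for illustration only
    image = read_func(Path("example_image.tif"))
    print(image.shape)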
careamics/file_io/read/tiff.py
@@ -0,0 +1,58 @@
+ """Functions to read tiff images."""
+
+ import logging
+ from fnmatch import fnmatch
+ from pathlib import Path
+
+ import numpy as np
+ import tifffile
+
+ from careamics.config.support import SupportedData
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
+     """
+     Read a tiff file and return a numpy array.
+
+     Parameters
+     ----------
+     file_path : Path
+         Path to a file.
+     *args : list
+         Additional arguments.
+     **kwargs : dict
+         Additional keyword arguments.
+
+     Returns
+     -------
+     np.ndarray
+         Resulting array.
+
+     Raises
+     ------
+     ValueError
+         If the file failed to open.
+     OSError
+         If the file failed to open.
+     ValueError
+         If the file is not a valid tiff.
+     ValueError
+         If the data dimensions are incorrect.
+     ValueError
+         If the axes length is incorrect.
+     """
+     if fnmatch(
+         file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+     ):
+         try:
+             array = tifffile.imread(file_path)
+         except (ValueError, OSError) as e:
+             logging.exception(f"Exception in file {file_path}: {e}, skipping it.")
+             raise e
+     else:
+         raise ValueError(f"File {file_path} is not a valid tiff.")
+
+     return array
careamics/file_io/read/zarr.py
@@ -0,0 +1,60 @@
+ """Function to read zarr images."""
+
+ from typing import Union
+
+ from zarr import Group, core, hierarchy, storage
+
+
+ def read_zarr(
+     zarr_source: Group, axes: str
+ ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
+     """Read a file and return a pointer.
+
+     Parameters
+     ----------
+     zarr_source : Group
+         Zarr storage.
+     axes : str
+         Axes of the data.
+
+     Returns
+     -------
+     np.ndarray
+         Pointer to zarr storage.
+
+     Raises
+     ------
+     ValueError, OSError
+         if a file is not a valid tiff or damaged.
+     ValueError
+         if data dimensions are not 2, 3 or 4.
+     ValueError
+         if axes parameter from config is not consistent with data dimensions.
+     """
+     if isinstance(zarr_source, hierarchy.Group):
+         array = zarr_source[0]
+
+     elif isinstance(zarr_source, storage.DirectoryStore):
+         raise NotImplementedError("DirectoryStore not supported yet")
+
+     elif isinstance(zarr_source, core.Array):
+         # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
+         if zarr_source.dtype == "O":
+             raise NotImplementedError("Object type not supported yet")
+         else:
+             array = zarr_source
+     else:
+         raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")
+
+     # sanity check on dimensions
+     if len(array.shape) < 2 or len(array.shape) > 4:
+         raise ValueError(
+             f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
+         )
+
+     # sanity check on axes length
+     if len(axes) != len(array.shape):
+         raise ValueError(f"Incorrect axes length (got {axes}).")
+
+     # arr = fix_axes(arr, axes)
+     return array
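
A small sketch of how `read_zarr` behaves, assuming the zarr v2 API targeted by the imports above (`zarr.array` returning a `core.Array`); the shape and axes below are purely illustrative:

    import numpy as np
    import zarr

    from careamics.file_io.read import read_zarr

    # in-memory zarr array with axes "SYX": 4 samples of 64x64 pixels
    data = zarr.array(np.random.rand(4, 64, 64))

    # returns a pointer to the array after the dimension and axes-length checks
    array = read_zarr(data, axes="SYX")
    print(array.shape)  # (4, 64, 64)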
careamics/file_io/write/__init__.py
@@ -0,0 +1,15 @@
+ """Functions relating to writing image files of different formats."""
+
+ __all__ = [
+     "get_write_func",
+     "write_tiff",
+     "WriteFunc",
+     "SupportedWriteType",
+ ]
+
+ from .get_func import (
+     SupportedWriteType,
+     WriteFunc,
+     get_write_func,
+ )
+ from .tiff import write_tiff
careamics/file_io/write/get_func.py
@@ -0,0 +1,63 @@
+ """Module to get write functions."""
+
+ from pathlib import Path
+ from typing import Literal, Protocol
+
+ from numpy.typing import NDArray
+
+ from careamics.config.support import SupportedData
+
+ from .tiff import write_tiff
+
+ SupportedWriteType = Literal["tiff", "custom"]
+
+
+ # This is very strict, arguments have to be called file_path & img
+ # Alternative? - doesn't capture *args & **kwargs
+ # WriteFunc = Callable[[Path, NDArray], None]
+ class WriteFunc(Protocol):
+     """Protocol for type hinting write functions."""
+
+     def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
+         """
+         Type hinted callables must match this function signature (not including self).
+
+         Parameters
+         ----------
+         file_path : pathlib.Path
+             Path to file.
+         img : numpy.ndarray
+             Image data to save.
+         *args
+             Other positional arguments.
+         **kwargs
+             Other keyword arguments.
+         """
+
+
+ WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
+     SupportedData.TIFF: write_tiff,
+ }
+
+
+ def get_write_func(data_type: SupportedWriteType) -> WriteFunc:
+     """
+     Get the write function for the data type.
+
+     Parameters
+     ----------
+     data_type : {"tiff", "custom"}
+         Data type.
+
+     Returns
+     -------
+     callable
+         Write function.
+     """
+     # error raised here if not supported
+     data_type_ = SupportedData(data_type)  # new variable for mypy
+     # error if no write func.
+     if data_type_ not in WRITE_FUNCS:
+         raise NotImplementedError(f"No write function for data type '{data_type}'.")
+
+     return WRITE_FUNCS[data_type_]
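
As a sketch, resolving a write function mirrors the read side; at this release only TIFF is registered in `WRITE_FUNCS`, so "custom" is expected to hit the `NotImplementedError` branch:

    from careamics.file_io.write import get_write_func, write_tiff

    write_func = get_write_func("tiff")
    assert write_func is write_tiff

    try:
        get_write_func("custom")  # accepted data type, but no write function registered
    except NotImplementedError as err:
        print(err)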
careamics/file_io/write/tiff.py
@@ -0,0 +1,40 @@
+ """Write tiff function."""
+
+ from fnmatch import fnmatch
+ from pathlib import Path
+
+ import tifffile
+ from numpy.typing import NDArray
+
+ from careamics.config.support import SupportedData
+
+
+ def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
+     # TODO: add link to tifffile docs for args and kwargs?
+     """
+     Write tiff files.
+
+     Parameters
+     ----------
+     file_path : pathlib.Path
+         Path to file.
+     img : numpy.ndarray
+         Image data to save.
+     *args
+         Positional arguments passed to `tifffile.imwrite`.
+     **kwargs
+         Keyword arguments passed to `tifffile.imwrite`.
+
+     Raises
+     ------
+     ValueError
+         When the file extension of `file_path` does not match the Unix shell-style
+         pattern '*.tif*'.
+     """
+     if not fnmatch(
+         file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+     ):
+         raise ValueError(
+             f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
+         )
+     tifffile.imwrite(file_path, img, *args, **kwargs)
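
A usage sketch for `write_tiff` with hypothetical output paths; the extension check uses the '*.tif*' pattern mentioned in the docstring:

    from pathlib import Path

    import numpy as np

    from careamics.file_io.write import write_tiff

    img = np.zeros((32, 32), dtype=np.float32)

    # '.tiff' matches '*.tif*', so the image is written
    write_tiff(Path("prediction.tiff"), img)

    # '.png' does not match, so a ValueError is raised
    try:
        write_tiff(Path("prediction.png"), img)
    except ValueError as err:
        print(err)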
careamics/lightning/__init__.py
@@ -0,0 +1,18 @@
+ """CAREamics PyTorch Lightning modules."""
+
+ __all__ = [
+     "FCNModule",
+     "VAEModule",
+     "create_careamics_module",
+     "TrainDataModule",
+     "create_train_datamodule",
+     "PredictDataModule",
+     "create_predict_datamodule",
+     "HyperParametersCallback",
+     "ProgressBarCallback",
+ ]
+
+ from .callbacks import HyperParametersCallback, ProgressBarCallback
+ from .lightning_module import FCNModule, VAEModule, create_careamics_module
+ from .predict_data_module import PredictDataModule, create_predict_datamodule
+ from .train_data_module import TrainDataModule, create_train_datamodule
careamics/lightning/callbacks/__init__.py
@@ -0,0 +1,11 @@
+ """Callbacks module."""
+
+ __all__ = [
+     "HyperParametersCallback",
+     "ProgressBarCallback",
+     "PredictionWriterCallback",
+ ]
+
+ from .hyperparameters_callback import HyperParametersCallback
+ from .prediction_writer_callback import PredictionWriterCallback
+ from .progress_bar_callback import ProgressBarCallback
careamics/lightning/callbacks/hyperparameters_callback.py
@@ -0,0 +1,49 @@
+ """Callback saving CAREamics configuration as hyperparameters in the model."""
+
+ from pytorch_lightning import LightningModule, Trainer
+ from pytorch_lightning.callbacks import Callback
+
+ from careamics.config import Configuration
+
+
+ class HyperParametersCallback(Callback):
+     """
+     Callback allowing saving CAREamics configuration as hyperparameters in the model.
+
+     This allows saving the configuration as dictionary in the checkpoints, and
+     loading it subsequently in a CAREamist instance.
+
+     Parameters
+     ----------
+     config : Configuration
+         CAREamics configuration to be saved as hyperparameter in the model.
+
+     Attributes
+     ----------
+     config : Configuration
+         CAREamics configuration to be saved as hyperparameter in the model.
+     """
+
+     def __init__(self, config: Configuration) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         config : Configuration
+             CAREamics configuration to be saved as hyperparameter in the model.
+         """
+         self.config = config
+
+     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+         """
+         Update the hyperparameters of the model with the configuration on train start.
+
+         Parameters
+         ----------
+         trainer : Trainer
+             PyTorch Lightning trainer, unused.
+         pl_module : LightningModule
+             PyTorch Lightning module.
+         """
+         pl_module.hparams.update(self.config.model_dump())
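
A short sketch of wiring `HyperParametersCallback` into a Lightning `Trainer`; `config` is assumed to be a valid CAREamics `Configuration` built elsewhere (its construction is omitted here):

    from pytorch_lightning import Trainer

    from careamics.lightning.callbacks import HyperParametersCallback

    # `config` is a CAREamics Configuration instance, built elsewhere
    callback = HyperParametersCallback(config)
    trainer = Trainer(callbacks=[callback])

    # on train start, the callback copies config.model_dump() into the module's
    # hparams, so checkpoints carry the full CAREamics configuration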
careamics/lightning/callbacks/prediction_writer_callback/__init__.py
@@ -0,0 +1,20 @@
+ """A package for the `PredictionWriterCallback` class and utilities."""
+
+ __all__ = [
+     "PredictionWriterCallback",
+     "create_write_strategy",
+     "WriteStrategy",
+     "WriteImage",
+     "CacheTiles",
+     "WriteTilesZarr",
+     "select_write_extension",
+     "select_write_func",
+ ]
+
+ from .prediction_writer_callback import PredictionWriterCallback
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
+ from .write_strategy_factory import (
+     create_write_strategy,
+     select_write_extension,
+     select_write_func,
+ )
careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py
@@ -0,0 +1,56 @@
+ """Module containing file path utilities for `WriteStrategy` to use."""
+
+ from pathlib import Path
+ from typing import Union
+
+ from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
+
+
+ # TODO: move to datasets package ?
+ def get_sample_file_path(
+     dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int
+ ) -> Path:
+     """
+     Get the file path for a particular sample.
+
+     Parameters
+     ----------
+     dataset : IterableTiledPredDataset or IterablePredDataset
+         Dataset.
+     sample_id : int
+         Sample ID, the index of the file in the dataset `dataset`.
+
+     Returns
+     -------
+     Path
+         The file path corresponding to the sample with the ID `sample_id`.
+     """
+     return dataset.data_files[sample_id]
+
+
+ def create_write_file_path(
+     dirpath: Path, file_path: Path, write_extension: str
+ ) -> Path:
+     """
+     Create the file name for the output file.
+
+     Takes the original file path, changes the directory to `dirpath` and changes
+     the extension to `write_extension`.
+
+     Parameters
+     ----------
+     dirpath : pathlib.Path
+         The output directory to write file to.
+     file_path : pathlib.Path
+         The original file path.
+     write_extension : str
+         The extension that output files should have.
+
+     Returns
+     -------
+     Path
+         The output file path.
+     """
+     file_name = Path(file_path.stem).with_suffix(write_extension)
+     file_path = dirpath / file_name
+     return file_path
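
To illustrate `create_write_file_path` with purely hypothetical paths: the stem of the input file is kept while the directory and extension are swapped.

    from pathlib import Path

    from careamics.lightning.callbacks.prediction_writer_callback.file_path_utils import (
        create_write_file_path,
    )

    out = create_write_file_path(
        dirpath=Path("predictions"),
        file_path=Path("raw_data/image_001.tif"),
        write_extension=".tiff",
    )
    print(out)  # predictions/image_001.tiff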