careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,333 @@
1
+ """Prediction Lightning data modules."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Literal, Optional, Union
5
+
6
+ import numpy as np
7
+ import pytorch_lightning as L
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import DataLoader
10
+
11
+ from careamics.config import InferenceConfig
12
+ from careamics.config.support import SupportedData
13
+ from careamics.dataset import (
14
+ InMemoryPredDataset,
15
+ InMemoryTiledPredDataset,
16
+ IterablePredDataset,
17
+ IterableTiledPredDataset,
18
+ )
19
+ from careamics.dataset.dataset_utils import list_files
20
+ from careamics.dataset.tiling.collate_tiles import collate_tiles
21
+ from careamics.file_io.read import get_read_func
22
+ from careamics.utils import get_logger
23
+
24
+ PredictDatasetType = Union[
25
+ InMemoryPredDataset,
26
+ InMemoryTiledPredDataset,
27
+ IterablePredDataset,
28
+ IterableTiledPredDataset,
29
+ ]
30
+
31
+ logger = get_logger(__name__)
32
+
33
+
34
class PredictDataModule(L.LightningDataModule):
    """
    CAREamics Lightning prediction data module.

    The data module can be used with Path, str or numpy arrays. The data can be either
    a folder containing images or a single file.

    To read custom data types, you can set `data_type` to `custom` in `pred_config`
    and provide a function that returns a numpy array from a path as
    `read_source_func` parameter. The function will receive a Path object and
    an axes string as arguments, the axes being derived from the `pred_config`.

    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
    "*.czi") to filter the files extension using `extension_filter`.

    Parameters
    ----------
    pred_config : InferenceConfig
        Pydantic model for CAREamics prediction configuration.
    pred_data : pathlib.Path or str or numpy.ndarray
        Prediction data, can be a path to a folder, a file or a numpy array.
    read_source_func : Callable, optional
        Function to read custom types, by default None.
    extension_filter : str, optional
        Filter to filter file extensions for custom types, by default "".
    dataloader_params : dict, optional
        Dataloader parameters, by default None (no extra parameters).
    """

    def __init__(
        self,
        pred_config: InferenceConfig,
        pred_data: Union[Path, str, NDArray],
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        dataloader_params: Optional[dict] = None,
    ) -> None:
        """
        Constructor.

        The data module can be used with Path, str or numpy arrays. The data can be
        either a folder containing images or a single file.

        To read custom data types, you can set `data_type` to `custom` in
        `pred_config` and provide a function that returns a numpy array from a path
        as `read_source_func` parameter. The function will receive a Path object and
        an axes string as arguments, the axes being derived from the `pred_config`.

        You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
        "*.czi") to filter the files extension using `extension_filter`.

        Parameters
        ----------
        pred_config : InferenceConfig
            Pydantic model for CAREamics prediction configuration.
        pred_data : pathlib.Path or str or numpy.ndarray
            Prediction data, can be a path to a folder, a file or a numpy array.
        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter to filter file extensions for custom types, by default "".
        dataloader_params : dict, optional
            Dataloader parameters, by default None (no extra parameters).

        Raises
        ------
        ValueError
            If the data type is `custom` and no `read_source_func` is provided.
        ValueError
            If the input is a numpy array and the data type is not `array`.
        ValueError
            If the input is a Path or a str and the data type is neither `tiff`
            nor `custom`.
        """
        if dataloader_params is None:
            dataloader_params = {}
        super().__init__()

        # check that a read source function is provided for custom types
        if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
            raise ValueError(
                f"Data type {SupportedData.CUSTOM} is not allowed without "
                f"specifying a `read_source_func` and an `extension_filter`."
            )

        # check correct input type
        if (
            isinstance(pred_data, np.ndarray)
            and pred_config.data_type != SupportedData.ARRAY
        ):
            raise ValueError(
                f"Received a numpy array as input, but the data type was set to "
                f"{pred_config.data_type}. Set the data type "
                f"to {SupportedData.ARRAY} to predict on numpy arrays."
            )

        # and that Path or str are passed, if tiff file type specified
        # (the check is on `pred_data`: a str path is as much a file input as a Path)
        elif isinstance(pred_data, (Path, str)) and (
            pred_config.data_type != SupportedData.TIFF
            and pred_config.data_type != SupportedData.CUSTOM
        ):
            raise ValueError(
                f"Received a path as input, but the data type was neither set to "
                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
                f" to {SupportedData.TIFF} or "
                f"{SupportedData.CUSTOM} to predict on files."
            )

        # configuration data
        self.prediction_config = pred_config
        self.data_type = pred_config.data_type
        self.batch_size = pred_config.batch_size
        self.dataloader_params = dataloader_params

        self.pred_data = pred_data
        self.tile_size = pred_config.tile_size
        self.tile_overlap = pred_config.tile_overlap

        # tiling is enabled only when both tile size and overlap are specified
        self.tiled = self.tile_size is not None and self.tile_overlap is not None

        # read source function (left unset for array data, which is never read
        # from disk)
        if pred_config.data_type == SupportedData.CUSTOM:
            # mypy check, the None case has already been rejected above
            assert read_source_func is not None

            self.read_source_func: Callable = read_source_func
        elif pred_config.data_type != SupportedData.ARRAY:
            self.read_source_func = get_read_func(pred_config.data_type)

        self.extension_filter = extension_filter

    def prepare_data(self) -> None:
        """
        Hook used to prepare the data before calling `setup`.

        For file-based inputs (Path or str), the list of files to predict on is
        stored in `self.pred_files`, which `setup` relies on for non-array data.
        """
        # if the data is a Path or a str
        if not isinstance(self.pred_data, np.ndarray):
            self.pred_files = list_files(
                self.pred_data, self.data_type, self.extension_filter
            )

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Hook called at the beginning of predict.

        Instantiates `self.predict_dataset` according to the data type (array or
        files) and to whether tiling is enabled.

        Parameters
        ----------
        stage : Optional[str], optional
            Stage, by default None.
        """
        # if numpy array
        if self.data_type == SupportedData.ARRAY:
            if self.tiled:
                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
            else:
                self.predict_dataset = InMemoryPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
        else:
            # file-based data, uses the file list created in `prepare_data`
            if self.tiled:
                self.predict_dataset = IterableTiledPredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )
            else:
                self.predict_dataset = IterablePredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            # tiled datasets need the dedicated tile collate function
            collate_fn=collate_tiles if self.tiled else None,
            **self.dataloader_params,
        )
225
+
226
+
227
def create_predict_datamodule(
    pred_data: Union[str, Path, NDArray],
    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
    axes: str,
    image_means: list[float],
    image_stds: list[float],
    tile_size: Optional[tuple[int, ...]] = None,
    tile_overlap: Optional[tuple[int, ...]] = None,
    batch_size: int = 1,
    tta_transforms: bool = True,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
    dataloader_params: Optional[dict] = None,
) -> PredictDataModule:
    """Create a CAREamics prediction Lightning datamodule.

    This function is used to explicitly pass the parameters usually contained in an
    `inference_model` configuration.

    Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements
    and are coherent. This can be done by creating a `Configuration` object beforehand
    and passing its parameters to the different Lightning modules.

    The data module can be used with Path, str or numpy arrays. To use array data, set
    `data_type` to `array` and pass a numpy array to `pred_data`.

    By default, CAREamics only supports types defined in
    `careamics.config.support.SupportedData`. To read custom data types, you can set
    `data_type` to `custom` and provide a function that returns a numpy array from a
    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
    (e.g. "*.jpeg") to filter the files extension using `extension_filter`.

    In `dataloader_params`, you can pass any parameter accepted by PyTorch
    dataloaders, except for `batch_size`, which is set by the `batch_size`
    parameter. A `batch_size` entry in `dataloader_params` is ignored; the
    dictionary passed by the caller is not modified.

    Parameters
    ----------
    pred_data : str or pathlib.Path or numpy.ndarray
        Prediction data.
    data_type : {"array", "tiff", "custom"}
        Data type, see `SupportedData` for available options.
    axes : str
        Axes of the data, chosen among SCZYX.
    image_means : list of float
        Mean values for normalization, only used if Normalization is defined.
    image_stds : list of float
        Std values for normalization, only used if Normalization is defined.
    tile_size : tuple of int, optional
        Tile size, 2D or 3D tile size.
    tile_overlap : tuple of int, optional
        Tile overlap, 2D or 3D tile overlap.
    batch_size : int
        Batch size.
    tta_transforms : bool, optional
        Use test time augmentation, by default True.
    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
        Filter for file extensions, used if `data_type` is `custom`, by default "".
    dataloader_params : dict, optional
        Pytorch dataloader parameters, by default None (no extra parameters).

    Returns
    -------
    PredictDataModule
        CAREamics prediction datamodule.

    Notes
    -----
    If you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image has fewer dimensions, as it may
    happen in the Z dimension, consider padding your image.
    """
    # Build a filtered copy rather than deleting the key in place, so that the
    # dictionary owned by the caller is left untouched. `batch_size` is dropped
    # because it is set through the prediction configuration.
    loader_params: dict[str, Any] = {
        key: value
        for key, value in (dataloader_params or {}).items()
        if key != "batch_size"
    }

    prediction_dict: dict[str, Any] = {
        "data_type": data_type,
        "tile_size": tile_size,
        "tile_overlap": tile_overlap,
        "axes": axes,
        "image_means": image_means,
        "image_stds": image_stds,
        "tta_transforms": tta_transforms,
        "batch_size": batch_size,
    }

    # validate configuration
    prediction_config = InferenceConfig(**prediction_dict)

    return PredictDataModule(
        pred_config=prediction_config,
        pred_data=pred_data,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        dataloader_params=loader_params,
    )