careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics has been flagged as potentially problematic.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,292 @@
+ from typing import Any, Optional, Union
+
+ import pytorch_lightning as L
+ from torch import Tensor, nn
+
+ from careamics.config import AlgorithmConfig
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedOptimizer,
+     SupportedScheduler,
+ )
+ from careamics.losses import loss_factory
+ from careamics.models.model_factory import model_factory
+ from careamics.transforms import Denormalize, ImageRestorationTTA
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+
+ class CAREamicsModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates a PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmConfig` Pydantic class.
+
+     Attributes
+     ----------
+     model : nn.Module
+         PyTorch model.
+     loss_func : nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     lr_scheduler_params : dict
+         Learning rate scheduler parameters.
+     """
+
+     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
+         """
+         CAREamics Lightning module.
+
+         This class encapsulates a PyTorch model along with the training,
+         validation, and testing logic. It is configured using an
+         `AlgorithmConfig` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : Union[AlgorithmConfig, dict]
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmConfig needs to be instantiated
+         if isinstance(algorithm_config, dict):
+             algorithm_config = AlgorithmConfig(**algorithm_config)
+
+         # create model and loss function
+         self.model: nn.Module = model_factory(algorithm_config.model)
+         self.loss_func = loss_factory(algorithm_config.loss)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = algorithm_config.optimizer.name
+         self.optimizer_params = algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters
+
+     def forward(self, x: Any) -> Any:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Any
+             Input tensor.
+
+         Returns
+         -------
+         Any
+             Output tensor.
+         """
+         return self.model(x)
+
+     def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, *aux = batch
+         out = self.model(x)
+         loss = self.loss_func(out, *aux)
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         return loss
+
+     def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+         """
+         x, *aux = batch
+         out = self.model(x)
+         val_loss = self.loss_func(out, *aux)
+
+         # log validation loss
+         self.log(
+             "val_loss",
+             val_loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+     def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         x, *aux = batch
+
+         # apply test-time augmentation if available
+         # TODO: probably won't work with batch size > 1
+         if self._trainer.datamodule.prediction_config.tta_transforms:
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # denormalize the output
+         denorm = Denormalize(
+             mean=self._trainer.datamodule.predict_dataset.mean,
+             std=self._trainer.datamodule.predict_dataset.std,
+         )
+         denormalized_output = denorm(image=output)["image"]
+
+         if len(aux) > 0:
+             return denormalized_output, aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
+
+
+ class CAREamicsModuleWrapper(CAREamicsModule):
+     """Class defining the API of the CAREamics Lightning layer.
+
+     This class exposes the parameters used to create an `AlgorithmConfig`
+     instance, triggering parameter validation.
+
+     Parameters
+     ----------
+     algorithm : Union[SupportedAlgorithm, str]
+         Algorithm to use for training (see SupportedAlgorithm).
+     loss : Union[SupportedLoss, str]
+         Loss function to use for training (see SupportedLoss).
+     architecture : Union[SupportedArchitecture, str]
+         Model architecture to use for training (see SupportedArchitecture).
+     model_parameters : dict, optional
+         Model parameters to use for training, by default {}. Model parameters are
+         defined in the relevant `torch.nn.Module` class, or Pydantic model (see
+         `careamics.config.architectures`).
+     optimizer : Union[SupportedOptimizer, str], optional
+         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
+     optimizer_parameters : dict, optional
+         Optimizer parameters to use for training, as defined in `torch.optim`, by
+         default {}.
+     lr_scheduler : Union[SupportedScheduler, str], optional
+         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
+         (see SupportedScheduler).
+     lr_scheduler_parameters : dict, optional
+         Learning rate scheduler parameters to use for training, as defined in
+         `torch.optim`, by default {}.
+     """
+
+     def __init__(
+         self,
+         algorithm: Union[SupportedAlgorithm, str],
+         loss: Union[SupportedLoss, str],
+         architecture: Union[SupportedArchitecture, str],
+         model_parameters: Optional[dict] = None,
+         optimizer: Union[SupportedOptimizer, str] = "Adam",
+         optimizer_parameters: Optional[dict] = None,
+         lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+         lr_scheduler_parameters: Optional[dict] = None,
+     ) -> None:
+         """
+         Wrapper for the CAREamics model, exposing all algorithm configuration
+         arguments.
+
+         Parameters
+         ----------
+         algorithm : Union[SupportedAlgorithm, str]
+             Algorithm to use for training (see SupportedAlgorithm).
+         loss : Union[SupportedLoss, str]
+             Loss function to use for training (see SupportedLoss).
+         architecture : Union[SupportedArchitecture, str]
+             Model architecture to use for training (see SupportedArchitecture).
+         model_parameters : dict, optional
+             Model parameters to use for training, by default {}. Model parameters
+             are defined in the relevant `torch.nn.Module` class, or Pydantic model
+             (see `careamics.config.architectures`).
+         optimizer : Union[SupportedOptimizer, str], optional
+             Optimizer to use for training, by default "Adam" (see
+             SupportedOptimizer).
+         optimizer_parameters : dict, optional
+             Optimizer parameters to use for training, as defined in `torch.optim`,
+             by default {}.
+         lr_scheduler : Union[SupportedScheduler, str], optional
+             Learning rate scheduler to use for training, by default
+             "ReduceLROnPlateau" (see SupportedScheduler).
+         lr_scheduler_parameters : dict, optional
+             Learning rate scheduler parameters to use for training, as defined in
+             `torch.optim`, by default {}.
+         """
+         # create an AlgorithmConfig-compatible dictionary
+         if lr_scheduler_parameters is None:
+             lr_scheduler_parameters = {}
+         if optimizer_parameters is None:
+             optimizer_parameters = {}
+         if model_parameters is None:
+             model_parameters = {}
+         algorithm_configuration = {
+             "algorithm": algorithm,
+             "loss": loss,
+             "optimizer": {
+                 "name": optimizer,
+                 "parameters": optimizer_parameters,
+             },
+             "lr_scheduler": {
+                 "name": lr_scheduler,
+                 "parameters": lr_scheduler_parameters,
+             },
+         }
+         model_configuration = {"architecture": architecture}
+         model_configuration.update(model_parameters)
+
+         # add model parameters to algorithm configuration
+         algorithm_configuration["model"] = model_configuration
+
+         # call the parent init using an AlgorithmConfig instance
+         super().__init__(AlgorithmConfig(**algorithm_configuration))
+
+ # TODO: add a load_from_checkpoint wrapper
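
For orientation, here is a minimal usage sketch of the wrapper above (not part of the diff). The string values "n2v" and "UNet" are assumed to match entries of the SupportedAlgorithm, SupportedLoss and SupportedArchitecture enums added in this release, and the learning rate key follows torch.optim.Adam:

    from careamics.lightning_module import CAREamicsModuleWrapper

    # Build a validated Lightning module from plain arguments; the wrapper
    # assembles an AlgorithmConfig under the hood and validates it via Pydantic.
    module = CAREamicsModuleWrapper(
        algorithm="n2v",                    # assumed SupportedAlgorithm value
        loss="n2v",                         # assumed SupportedLoss value
        architecture="UNet",                # assumed SupportedArchitecture value
        optimizer="Adam",
        optimizer_parameters={"lr": 1e-4},  # torch.optim.Adam keyword
        lr_scheduler="ReduceLROnPlateau",
    )
    # `module` is a pytorch_lightning.LightningModule and can be passed to a
    # pytorch_lightning.Trainer together with a CAREamics data module.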
@@ -0,0 +1,396 @@
+ """Prediction Lightning data modules."""
+ from pathlib import Path
+ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ from albumentations import Compose
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataloader import default_collate
+
+ from careamics.config import InferenceConfig
+ from careamics.config.support import SupportedData
+ from careamics.config.tile_information import TileInformation
+ from careamics.dataset.dataset_utils import (
+     get_read_func,
+     list_files,
+ )
+ from careamics.dataset.in_memory_dataset import (
+     InMemoryPredictionDataset,
+ )
+ from careamics.dataset.iterable_dataset import (
+     IterablePredictionDataset,
+ )
+ from careamics.utils import get_logger
+
+ PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+
+ logger = get_logger(__name__)
+
+
+ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+     """
+     Collate tiles received from the CAREamics prediction dataloader.
+
+     The CAREamics prediction dataloader returns tuples of arrays and
+     TileInformation. For non-tiled data, this function returns only the arrays.
+     For tiled data, it returns the arrays, the last-tile flag, the array shape,
+     the overlap crop coordinates and the stitch coordinates.
+
+     Parameters
+     ----------
+     batch : List[Tuple[np.ndarray, TileInformation]]
+         Batch of tiles.
+
+     Returns
+     -------
+     Any
+         Collated batch.
+     """
+     first_tile_info: TileInformation = batch[0][1]
+     # if not tiled, then return arrays
+     if not first_tile_info.tiled:
+         arrays, _ = zip(*batch)
+
+         return default_collate(arrays)
+     # otherwise, unpack the last_tile flag, array shape and tile coordinates
+     else:
+         new_batch = [
+             (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
+             for tile, t in batch
+         ]
+
+         return default_collate(new_batch)
+
+
+ class CAREamicsPredictData(L.LightningDataModule):
+     """
+     CAREamics Lightning prediction data module.
+
+     The data module can be used with Path, str or numpy arrays. The data can be
+     either a folder containing images or a single file.
+
+     To read custom data types, you can set `data_type` to `custom` in
+     `pred_config` and provide a function that returns a numpy array from a path
+     as the `read_source_func` parameter. The function will receive a Path object
+     and an axes string as arguments, the axes being derived from the
+     `pred_config`.
+
+     You can also provide a `fnmatch` and `Path.rglob` compatible expression
+     (e.g. "*.czi") to filter file extensions using `extension_filter`.
+
+     Parameters
+     ----------
+     pred_config : InferenceConfig
+         Pydantic model for CAREamics prediction configuration.
+     pred_data : Union[Path, str, np.ndarray]
+         Prediction data, can be a path to a folder, a file or a numpy array.
+     read_source_func : Optional[Callable], optional
+         Function to read custom types, by default None.
+     extension_filter : str, optional
+         Filter for file extensions of custom types, by default "".
+     dataloader_params : dict, optional
+         Dataloader parameters, by default {}.
+     """
+
+     def __init__(
+         self,
+         pred_config: InferenceConfig,
+         pred_data: Union[Path, str, np.ndarray],
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         dataloader_params: Optional[dict] = None,
+     ) -> None:
+         """
+         Constructor.
+
+         The data module can be used with Path, str or numpy arrays. The data can
+         be either a folder containing images or a single file.
+
+         To read custom data types, you can set `data_type` to `custom` in
+         `pred_config` and provide a function that returns a numpy array from a
+         path as the `read_source_func` parameter. The function will receive a
+         Path object and an axes string as arguments, the axes being derived from
+         the `pred_config`.
+
+         You can also provide a `fnmatch` and `Path.rglob` compatible expression
+         (e.g. "*.czi") to filter file extensions using `extension_filter`.
+
+         Parameters
+         ----------
+         pred_config : InferenceConfig
+             Pydantic model for CAREamics prediction configuration.
+         pred_data : Union[Path, str, np.ndarray]
+             Prediction data, can be a path to a folder, a file or a numpy array.
+         read_source_func : Optional[Callable], optional
+             Function to read custom types, by default None.
+         extension_filter : str, optional
+             Filter for file extensions of custom types, by default "".
+         dataloader_params : dict, optional
+             Dataloader parameters, by default {}.
+
+         Raises
+         ------
+         ValueError
+             If the data type is `custom` and no `read_source_func` is provided.
+         ValueError
+             If the data type is `array` and the input is not a numpy array.
+         ValueError
+             If the data type is `tiff` and the input is neither a Path nor a str.
+         """
+         if dataloader_params is None:
+             dataloader_params = {}
+         super().__init__()
+
+         # check that a read source function is provided for custom types
+         if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+             raise ValueError(
+                 f"Data type {SupportedData.CUSTOM} is not allowed without "
+                 f"specifying a `read_source_func` and an `extension_filter`."
+             )
+
+         # check correct input type
+         if (
+             isinstance(pred_data, np.ndarray)
+             and pred_config.data_type != SupportedData.ARRAY
+         ):
+             raise ValueError(
+                 f"Received a numpy array as input, but the data type was set to "
+                 f"{pred_config.data_type}. Set the data type "
+                 f"to {SupportedData.ARRAY} to predict on numpy arrays."
+             )
+
+         # and that Path or str are passed, if tiff file type specified
+         elif (isinstance(pred_data, Path) or isinstance(pred_data, str)) and (
+             pred_config.data_type != SupportedData.TIFF
+             and pred_config.data_type != SupportedData.CUSTOM
+         ):
+             raise ValueError(
+                 f"Received a path as input, but the data type was neither set to "
+                 f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data "
+                 f"type to {SupportedData.TIFF} or "
+                 f"{SupportedData.CUSTOM} to predict on files."
+             )
+
+         # configuration data
+         self.prediction_config = pred_config
+         self.data_type = pred_config.data_type
+         self.batch_size = pred_config.batch_size
+         self.dataloader_params = dataloader_params
+
+         self.pred_data = pred_data
+         self.tile_size = pred_config.tile_size
+         self.tile_overlap = pred_config.tile_overlap
+
+         # read source function
+         if pred_config.data_type == SupportedData.CUSTOM:
+             # mypy check
+             assert read_source_func is not None
+
+             self.read_source_func: Callable = read_source_func
+         elif pred_config.data_type != SupportedData.ARRAY:
+             self.read_source_func = get_read_func(pred_config.data_type)
+
+         self.extension_filter = extension_filter
+
+     def prepare_data(self) -> None:
+         """Hook used to prepare the data before calling `setup`."""
+         # if the data is a Path or a str
+         if not isinstance(self.pred_data, np.ndarray):
+             self.pred_files = list_files(
+                 self.pred_data, self.data_type, self.extension_filter
+             )
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         """
+         Hook called at the beginning of predict.
+
+         Parameters
+         ----------
+         stage : Optional[str], optional
+             Stage, by default None.
+         """
+         # if numpy array
+         if self.data_type == SupportedData.ARRAY:
+             # prediction dataset
+             self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
+                 prediction_config=self.prediction_config,
+                 inputs=self.pred_data,
+             )
+         else:
+             self.predict_dataset = IterablePredictionDataset(
+                 prediction_config=self.prediction_config,
+                 src_files=self.pred_files,
+                 read_source_func=self.read_source_func,
+             )
+
+     def predict_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for prediction.
+
+         Returns
+         -------
+         DataLoader
+             Prediction dataloader.
+         """
+         return DataLoader(
+             self.predict_dataset,
+             batch_size=self.batch_size,
+             collate_fn=_collate_tiles,
+             **self.dataloader_params,
+         )  # TODO: check that workers are used
+
+
+ class PredictDataWrapper(CAREamicsPredictData):
+     """
+     Wrapper around the CAREamics inference Lightning data module.
+
+     This class is used to explicitly pass the parameters usually contained in an
+     `inference_model` configuration.
+
+     Since the Lightning datamodule has no access to the model, make sure that
+     the parameters passed to the datamodule are consistent with the model's
+     requirements.
+
+     The data module can be used with Path, str or numpy arrays. To use array
+     data, set `data_type` to `array` and pass a numpy array to `pred_data`.
+
+     The default transformations applied to the images are defined in
+     `careamics.config.inference_model`. To use different transformations, pass a
+     list of transforms or an albumentations `Compose` as the `transforms`
+     parameter. See the examples for more details.
+
+     The `mean` and `std` parameters are only used if Normalization is defined
+     either in the default transformations or in the `transforms` parameter, but
+     not with a `Compose` object. If you pass a `Normalization` transform in a
+     list as `transforms`, then the mean and std parameters will be overwritten
+     by those passed to this method.
+
+     By default, CAREamics only supports types defined in
+     `careamics.config.support.SupportedData`. To read custom data types, you can
+     set `data_type` to `custom` and provide a function that returns a numpy
+     array from a path. Additionally, pass a `fnmatch` and `Path.rglob`
+     compatible expression (e.g. "*.jpeg") to filter file extensions using
+     `extension_filter`.
+
+     In `dataloader_params`, you can pass any parameter accepted by PyTorch
+     dataloaders, except for `batch_size`, which is set by the `batch_size`
+     parameter.
+
+     Parameters
+     ----------
+     pred_data : Union[str, Path, np.ndarray]
+         Prediction data.
+     data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+         Data type, see `SupportedData` for available options.
+     mean : float
+         Mean value for normalization, only used if Normalization is defined in
+         the transforms.
+     std : float
+         Standard deviation value for normalization, only used if Normalization
+         is defined in the transforms.
+     tile_size : Tuple[int, ...], optional
+         Tile size, 2D or 3D, by default None.
+     tile_overlap : Tuple[int, ...], optional
+         Tile overlap, 2D or 3D, by default None.
+     axes : str
+         Axes of the data, chosen among SCZYX.
+     batch_size : int
+         Batch size.
+     tta_transforms : bool, optional
+         Use test-time augmentation, by default True.
+     transforms : Optional[Union[List, Compose]], optional
+         List of transforms to apply to prediction patches. If None, default
+         transforms are applied.
+     read_source_func : Optional[Callable], optional
+         Function to read the source data, used if `data_type` is `custom`, by
+         default None.
+     extension_filter : str, optional
+         Filter for file extensions, used if `data_type` is `custom`, by default "".
+     dataloader_params : dict, optional
+         PyTorch dataloader parameters, by default {}.
+     """
+
+     def __init__(
+         self,
+         pred_data: Union[str, Path, np.ndarray],
+         data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+         mean: float,
+         std: float,
+         tile_size: Optional[Tuple[int, ...]] = None,
+         tile_overlap: Optional[Tuple[int, ...]] = None,
+         axes: str = "YX",
+         batch_size: int = 1,
+         tta_transforms: bool = True,
+         transforms: Optional[Union[List, Compose]] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         dataloader_params: Optional[dict] = None,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         pred_data : Union[str, Path, np.ndarray]
+             Prediction data.
+         data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+             Data type, see `SupportedData` for available options.
+         mean : float
+             Mean value for normalization, only used if Normalization is defined
+             in the transforms.
+         std : float
+             Standard deviation value for normalization, only used if
+             Normalization is defined in the transforms.
+         tile_size : Tuple[int, ...], optional
+             Tile size, 2D or 3D, by default None.
+         tile_overlap : Tuple[int, ...], optional
+             Tile overlap, 2D or 3D, by default None.
+         axes : str
+             Axes of the data, chosen among SCZYX.
+         batch_size : int
+             Batch size.
+         tta_transforms : bool, optional
+             Use test-time augmentation, by default True.
+         transforms : Optional[Union[List, Compose]], optional
+             List of transforms to apply to prediction patches. If None, default
+             transforms are applied.
+         read_source_func : Optional[Callable], optional
+             Function to read the source data, used if `data_type` is `custom`,
+             by default None.
+         extension_filter : str, optional
+             Filter for file extensions, used if `data_type` is `custom`, by
+             default "".
+         dataloader_params : dict, optional
+             PyTorch dataloader parameters, by default {}.
+         """
+         if dataloader_params is None:
+             dataloader_params = {}
+         prediction_dict = {
+             "data_type": data_type,
+             "tile_size": tile_size,
+             "tile_overlap": tile_overlap,
+             "axes": axes,
+             "mean": mean,
+             "std": std,
+             "tta_transforms": tta_transforms,
+             "batch_size": batch_size,
+         }
+
+         # if transforms are passed (otherwise the default ones are used)
+         if transforms is not None:
+             prediction_dict["transforms"] = transforms
+
+         # validate configuration
+         self.prediction_config = InferenceConfig(**prediction_dict)
+
+         # sanity check on the dataloader parameters
+         if "batch_size" in dataloader_params:
+             # remove it
+             del dataloader_params["batch_size"]
+
+         super().__init__(
+             pred_config=self.prediction_config,
+             pred_data=pred_data,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+             dataloader_params=dataloader_params,
+         )
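
And a matching sketch for the prediction data module above (again not part of the diff; the mean, std and tile shapes are illustrative values that would normally come from training):

    import numpy as np

    from careamics.lightning_prediction_datamodule import PredictDataWrapper

    # Predict on an in-memory array: data_type="array" routes the input to
    # InMemoryPredictionDataset in setup().
    image = np.random.rand(256, 256).astype(np.float32)
    predict_data = PredictDataWrapper(
        pred_data=image,
        data_type="array",
        mean=0.5,              # illustrative; use the training statistics
        std=0.2,               # illustrative; use the training statistics
        tile_size=(128, 128),  # 2D tiling
        tile_overlap=(48, 48),
        axes="YX",
        batch_size=1,
    )
    # predictions = pytorch_lightning.Trainer().predict(module, datamodule=predict_data)
    # would return denormalized outputs (see CAREamicsModule.predict_step above).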