careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_module.py (new file)
@@ -0,0 +1,292 @@
+ from typing import Any, Optional, Union
+
+ import pytorch_lightning as L
+ from torch import Tensor, nn
+
+ from careamics.config import AlgorithmModel
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedOptimizer,
+     SupportedScheduler,
+ )
+ from careamics.losses import loss_factory
+ from careamics.models.model_factory import model_factory
+ from careamics.transforms import Denormalize, ImageRestorationTTA
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+
+ class CAREamicsKiln(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates a PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Attributes
+     ----------
+     model : nn.Module
+         PyTorch model.
+     loss_func : nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None:
+         """
+         CAREamics Lightning module.
+
+         This class encapsulates a PyTorch model along with the training,
+         validation, and testing logic. It is configured using an
+         `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : Union[AlgorithmModel, dict]
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+         if isinstance(algorithm_config, dict):
+             algorithm_config = AlgorithmModel(**algorithm_config)
+
+         # create model and loss function
+         self.model: nn.Module = model_factory(algorithm_config.model)
+         self.loss_func = loss_factory(algorithm_config.loss)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = algorithm_config.optimizer.name
+         self.optimizer_params = algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters
+
+     def forward(self, x: Any) -> Any:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Any
+             Input tensor.
+
+         Returns
+         -------
+         Any
+             Output tensor.
+         """
+         return self.model(x)
+
+     def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, *aux = batch
+         out = self.model(x)
+         loss = self.loss_func(out, *aux)
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         return loss
+
+     def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+         """
+         x, *aux = batch
+         out = self.model(x)
+         val_loss = self.loss_func(out, *aux)
+
+         # log validation loss
+         self.log(
+             "val_loss",
+             val_loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+     def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         x, *aux = batch
+
+         # apply test-time augmentation if available
+         # TODO: probably won't work with batch size > 1
+         if self._trainer.datamodule.prediction_config.tta_transforms:
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # denormalize the output
+         denorm = Denormalize(
+             mean=self._trainer.datamodule.predict_dataset.mean,
+             std=self._trainer.datamodule.predict_dataset.std,
+         )
+         denormalized_output = denorm(image=output)["image"]
+
+         if len(aux) > 0:
+             return denormalized_output, aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
+
+
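A minimal sketch (not from the diff) of configuring CAREamicsKiln from a plain dictionary, as done when restoring from a checkpoint. The dictionary keys mirror the AlgorithmModel fields used above; the specific values ("n2v", "UNet", "Adam", "ReduceLROnPlateau") are assumptions that must be members of the corresponding Supported* enums. The CAREamicsModule class below offers the same configuration as keyword arguments.

from careamics.lightning_module import CAREamicsKiln

# hypothetical configuration; keys follow the AlgorithmModel structure
config = {
    "algorithm": "n2v",  # assumed member of SupportedAlgorithm
    "loss": "n2v",  # assumed member of SupportedLoss
    "model": {"architecture": "UNet"},  # assumed member of SupportedArchitecture
    "optimizer": {"name": "Adam", "parameters": {"lr": 1e-4}},
    "lr_scheduler": {"name": "ReduceLROnPlateau", "parameters": {}},
}
module = CAREamicsKiln(config)  # the dict is validated into an AlgorithmModel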
+ class CAREamicsModule(CAREamicsKiln):
+     """Class defining the API of the CAREamics Lightning layer.
+
+     This class exposes the parameters used to create an AlgorithmModel instance,
+     triggering parameter validation.
+
+     Parameters
+     ----------
+     algorithm : Union[SupportedAlgorithm, str]
+         Algorithm to use for training (see SupportedAlgorithm).
+     loss : Union[SupportedLoss, str]
+         Loss function to use for training (see SupportedLoss).
+     architecture : Union[SupportedArchitecture, str]
+         Model architecture to use for training (see SupportedArchitecture).
+     model_parameters : dict, optional
+         Model parameters to use for training, by default {}. Model parameters are
+         defined in the relevant `torch.nn.Module` class, or Pydantic model (see
+         `careamics.config.architectures`).
+     optimizer : Union[SupportedOptimizer, str], optional
+         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
+     optimizer_parameters : dict, optional
+         Optimizer parameters to use for training, as defined in `torch.optim`, by
+         default {}.
+     lr_scheduler : Union[SupportedScheduler, str], optional
+         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
+         (see SupportedScheduler).
+     lr_scheduler_parameters : dict, optional
+         Learning rate scheduler parameters to use for training, as defined in
+         `torch.optim`, by default {}.
+     """
+
+     def __init__(
+         self,
+         algorithm: Union[SupportedAlgorithm, str],
+         loss: Union[SupportedLoss, str],
+         architecture: Union[SupportedArchitecture, str],
+         model_parameters: Optional[dict] = None,
+         optimizer: Union[SupportedOptimizer, str] = "Adam",
+         optimizer_parameters: Optional[dict] = None,
+         lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+         lr_scheduler_parameters: Optional[dict] = None,
+     ) -> None:
+         """
+         Wrapper for the CAREamics model, exposing all algorithm configuration
+         arguments.
+
+         Parameters
+         ----------
+         algorithm : Union[SupportedAlgorithm, str]
+             Algorithm to use for training (see SupportedAlgorithm).
+         loss : Union[SupportedLoss, str]
+             Loss function to use for training (see SupportedLoss).
+         architecture : Union[SupportedArchitecture, str]
+             Model architecture to use for training (see SupportedArchitecture).
+         model_parameters : dict, optional
+             Model parameters to use for training, by default {}. Model parameters
+             are defined in the relevant `torch.nn.Module` class, or Pydantic model
+             (see `careamics.config.architectures`).
+         optimizer : Union[SupportedOptimizer, str], optional
+             Optimizer to use for training, by default "Adam" (see
+             SupportedOptimizer).
+         optimizer_parameters : dict, optional
+             Optimizer parameters to use for training, as defined in `torch.optim`,
+             by default {}.
+         lr_scheduler : Union[SupportedScheduler, str], optional
+             Learning rate scheduler to use for training, by default
+             "ReduceLROnPlateau" (see SupportedScheduler).
+         lr_scheduler_parameters : dict, optional
+             Learning rate scheduler parameters to use for training, as defined in
+             `torch.optim`, by default {}.
+         """
+         # create an AlgorithmModel-compatible dictionary
+         if lr_scheduler_parameters is None:
+             lr_scheduler_parameters = {}
+         if optimizer_parameters is None:
+             optimizer_parameters = {}
+         if model_parameters is None:
+             model_parameters = {}
+         algorithm_configuration = {
+             "algorithm": algorithm,
+             "loss": loss,
+             "optimizer": {
+                 "name": optimizer,
+                 "parameters": optimizer_parameters,
+             },
+             "lr_scheduler": {
+                 "name": lr_scheduler,
+                 "parameters": lr_scheduler_parameters,
+             },
+         }
+         model_configuration = {"architecture": architecture}
+         model_configuration.update(model_parameters)
+
+         # add model parameters to the algorithm configuration
+         algorithm_configuration["model"] = model_configuration
+
+         # call the parent init using an AlgorithmModel instance
+         super().__init__(AlgorithmModel(**algorithm_configuration))
+
+ # TODO add load_from_checkpoint wrapper
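A minimal sketch (not from the diff) of the CAREamicsModule wrapper above, which assembles the configuration dictionary and triggers validation; the enum values are again assumptions:

from careamics.lightning_module import CAREamicsModule

module = CAREamicsModule(
    algorithm="n2v",  # assumed member of SupportedAlgorithm
    loss="n2v",  # assumed member of SupportedLoss
    architecture="UNet",  # assumed member of SupportedArchitecture
    optimizer_parameters={"lr": 1e-4},
)
# optimizer and lr_scheduler keep their defaults, "Adam" and "ReduceLROnPlateau"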
careamics/lightning_prediction_datamodule.py (new file)
@@ -0,0 +1,390 @@
+ from pathlib import Path
+ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ from albumentations import Compose
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataloader import default_collate
+
+ from careamics.config import InferenceModel
+ from careamics.config.support import SupportedData
+ from careamics.config.tile_information import TileInformation
+ from careamics.dataset.dataset_utils import (
+     get_read_func,
+     list_files,
+ )
+ from careamics.dataset.in_memory_dataset import (
+     InMemoryPredictionDataset,
+ )
+ from careamics.dataset.iterable_dataset import (
+     IterablePredictionDataset,
+ )
+ from careamics.utils import get_logger
+
+ PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+
+ logger = get_logger(__name__)
+
+
+ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+     """
+     Collate tiles received from the CAREamics prediction dataloader.
+
+     The CAREamics prediction dataloader returns tuples of arrays and
+     TileInformation. For non-tiled data, this function returns only the arrays.
+     For tiled data, it returns the arrays, the last tile flag, the array shape,
+     the overlap crop coordinates and the stitch coordinates.
+
+     Parameters
+     ----------
+     batch : List[Tuple[np.ndarray, TileInformation]]
+         Batch of tiles.
+
+     Returns
+     -------
+     Any
+         Collated batch.
+     """
+     first_tile_info: TileInformation = batch[0][1]
+     # if not tiled, then return arrays
+     if not first_tile_info.tiled:
+         arrays, _ = zip(*batch)
+
+         return default_collate(arrays)
+     # otherwise, make the last_tile flag and the coordinates explicit
+     else:
+         new_batch = [
+             (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
+             for tile, t in batch
+         ]
+
+         return default_collate(new_batch)
+
+
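A sketch (not from the diff) of the two collation paths. The TileInformation construction is hypothetical: the field names are taken from the attributes accessed above, but the actual Pydantic model may require different arguments.

import numpy as np
from careamics.config.tile_information import TileInformation

tile = np.zeros((1, 64, 64), dtype=np.float32)
info = TileInformation(  # hypothetical construction
    array_shape=(1, 256, 256),
    tiled=True,
    last_tile=False,
    overlap_crop_coords=((8, 56), (8, 56)),
    stitch_coords=((0, 48), (0, 48)),
)
collated = _collate_tiles([(tile, info), (tile, info)])
# tiled: (tiles, last_tile flags, array shapes, overlap crops, stitch coords)
# non-tiled (info.tiled is False): only the collated tiles are returned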
+ class CAREamicsClay(L.LightningDataModule):
+     """
+     LightningDataModule for the prediction dataset.
+
+     The data module can be used with Path, str or numpy arrays. The data can be
+     either a folder containing images or a single file.
+
+     To read custom data types, you can set `data_type` to `custom` in the
+     prediction configuration and provide a function that returns a numpy array
+     from a path as the `read_source_func` parameter. The function will receive a
+     Path object and an axes string as arguments, the axes being derived from the
+     configuration.
+
+     You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
+     "*.czi") to filter the file extensions using `extension_filter`.
+
+     Parameters
+     ----------
+     prediction_config : InferenceModel
+         Pydantic model for CAREamics prediction configuration.
+     pred_data : Union[Path, str, np.ndarray]
+         Prediction data, can be a path to a folder, a file or a numpy array.
+     read_source_func : Optional[Callable], optional
+         Function to read custom types, by default None.
+     extension_filter : str, optional
+         Filter for file extensions for custom types, by default "".
+     dataloader_params : dict, optional
+         Dataloader parameters, by default {}.
+     """
+
+     def __init__(
+         self,
+         prediction_config: InferenceModel,
+         pred_data: Union[Path, str, np.ndarray],
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         dataloader_params: Optional[dict] = None,
+     ) -> None:
+         """
+         Constructor.
+
+         The data module can be used with Path, str or numpy arrays. The data can
+         be either a folder containing images or a single file.
+
+         To read custom data types, you can set `data_type` to `custom` in the
+         prediction configuration and provide a function that returns a numpy
+         array from a path as the `read_source_func` parameter. The function will
+         receive a Path object and an axes string as arguments, the axes being
+         derived from the configuration.
+
+         You can also provide a `fnmatch` and `Path.rglob` compatible expression
+         (e.g. "*.czi") to filter the file extensions using `extension_filter`.
+
+         Parameters
+         ----------
+         prediction_config : InferenceModel
+             Pydantic model for CAREamics prediction configuration.
+         pred_data : Union[Path, str, np.ndarray]
+             Prediction data, can be a path to a folder, a file or a numpy array.
+         read_source_func : Optional[Callable], optional
+             Function to read custom types, by default None.
+         extension_filter : str, optional
+             Filter for file extensions for custom types, by default "".
+         dataloader_params : dict, optional
+             Dataloader parameters, by default {}.
+
+         Raises
+         ------
+         ValueError
+             If the data type is `custom` and no `read_source_func` is provided.
+         ValueError
+             If the data type is `array` and the input is not a numpy array.
+         ValueError
+             If the data type is `tiff` and the input is neither a Path nor a str.
+         """
+         if dataloader_params is None:
+             dataloader_params = {}
+         super().__init__()
+
+         # check that a read source function is provided for custom types
+         if (
+             prediction_config.data_type == SupportedData.CUSTOM
+             and read_source_func is None
+         ):
+             raise ValueError(
+                 f"Data type {SupportedData.CUSTOM} is not allowed without "
+                 f"specifying a `read_source_func`."
+             )
+
+         # and that arrays are passed, if array type specified
+         elif prediction_config.data_type == SupportedData.ARRAY and not isinstance(
+             pred_data, np.ndarray
+         ):
+             raise ValueError(
+                 f"Expected array input (see configuration.data.data_type), but got "
+                 f"{type(pred_data)} instead."
+             )
+
+         # and that Path or str are passed, if tiff file type specified
+         elif prediction_config.data_type == SupportedData.TIFF and not (
+             isinstance(pred_data, Path) or isinstance(pred_data, str)
+         ):
+             raise ValueError(
+                 f"Expected Path or str input (see configuration.data.data_type), "
+                 f"but got {type(pred_data)} instead."
+             )
+
+         # configuration data
+         self.prediction_config = prediction_config
+         self.data_type = prediction_config.data_type
+         self.batch_size = prediction_config.batch_size
+         self.dataloader_params = dataloader_params
+
+         self.pred_data = pred_data
+         self.tile_size = prediction_config.tile_size
+         self.tile_overlap = prediction_config.tile_overlap
+
+         # read source function
+         if prediction_config.data_type == SupportedData.CUSTOM:
+             # mypy check
+             assert read_source_func is not None
+
+             self.read_source_func: Callable = read_source_func
+         elif prediction_config.data_type != SupportedData.ARRAY:
+             self.read_source_func = get_read_func(prediction_config.data_type)
+
+         self.extension_filter = extension_filter
+
+     def prepare_data(self) -> None:
+         """Hook used to prepare the data before calling `setup`."""
+         # if the data is a Path or a str
+         if not isinstance(self.pred_data, np.ndarray):
+             self.pred_files = list_files(
+                 self.pred_data, self.data_type, self.extension_filter
+             )
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         """
+         Hook called at the beginning of predict.
+
+         Parameters
+         ----------
+         stage : Optional[str], optional
+             Stage, by default None.
+         """
+         # if numpy array
+         if self.data_type == SupportedData.ARRAY:
+             # prediction dataset
+             self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
+                 prediction_config=self.prediction_config,
+                 inputs=self.pred_data,
+             )
+         else:
+             self.predict_dataset = IterablePredictionDataset(
+                 prediction_config=self.prediction_config,
+                 src_files=self.pred_files,
+                 read_source_func=self.read_source_func,
+             )
+
+     def predict_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for prediction.
+
+         Returns
+         -------
+         DataLoader
+             Prediction dataloader.
+         """
+         return DataLoader(
+             self.predict_dataset,
+             batch_size=self.batch_size,
+             collate_fn=_collate_tiles,
+             **self.dataloader_params,
+         )  # TODO check workers are used
+
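A sketch (not from the diff) of a custom reader wired into CAREamicsClay, following the contract described above: the function receives a Path and an axes string and returns a numpy array. The `.npy` reader and the `config` variable are assumptions.

from pathlib import Path

import numpy as np

def read_npy(path: Path, axes: str) -> np.ndarray:
    # hypothetical reader; `axes` is derived from the prediction configuration
    return np.load(path)

datamodule = CAREamicsClay(
    prediction_config=config,  # an InferenceModel with data_type="custom"
    pred_data="preds/",  # folder scanned with the extension filter
    read_source_func=read_npy,
    extension_filter="*.npy",
)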
+ class CAREamicsPredictDataModule(CAREamicsClay):
+     """
+     LightningDataModule wrapper of an inference dataset.
+
+     Since the lightning datamodule has no access to the model, make sure that the
+     parameters passed to the datamodule are consistent with the model's
+     requirements.
+
+     The data module can be used with Path, str or numpy arrays. To use array
+     data, set `data_type` to `array` and pass a numpy array to `pred_data`.
+
+     The default transformations applied to the images are defined in
+     `careamics.config.inference_model`. To use different transformations, pass a
+     list of transforms or an albumentations `Compose` as the `transforms`
+     parameter. See examples for more details.
+
+     The `mean` and `std` parameters are only used if a Normalization transform is
+     defined, either in the default transformations or in the `transforms`
+     parameter, but not when a `Compose` object is passed. If you pass a
+     `Normalization` transform in a list as `transforms`, its mean and std values
+     will be overwritten by those passed to this method.
+
+     By default, CAREamics only supports the types defined in
+     `careamics.config.support.SupportedData`. To read custom data types, you can
+     set `data_type` to `custom` and provide a function that returns a numpy array
+     from a path. Additionally, pass a `fnmatch` and `Path.rglob` compatible
+     expression (e.g. "*.jpeg") to filter the file extensions using
+     `extension_filter`.
+
+     In `dataloader_params`, you can pass any parameter accepted by PyTorch
+     dataloaders, except for `batch_size`, which is set by the `batch_size`
+     parameter.
+
+     Parameters
+     ----------
+     pred_data : Union[str, Path, np.ndarray]
+         Prediction data.
+     data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+         Data type, see `SupportedData` for available options.
+     mean : float
+         Mean value for normalization, only used if a Normalization transform is
+         defined.
+     std : float
+         Standard deviation value for normalization, only used if a Normalization
+         transform is defined.
+     tile_size : Optional[Tuple[int, ...]], optional
+         Tile size, 2D or 3D tile size, by default None.
+     tile_overlap : Optional[Tuple[int, ...]], optional
+         Tile overlap, 2D or 3D tile overlap, by default None.
+     axes : str, optional
+         Axes of the data, chosen among SCZYX, by default "YX".
+     batch_size : int, optional
+         Batch size, by default 1.
+     tta_transforms : bool, optional
+         Use test time augmentation, by default True.
+     transforms : Optional[Union[List, Compose]], optional
+         List of transforms to apply to prediction patches. If None, default
+         transforms are applied.
+     read_source_func : Optional[Callable], optional
+         Function to read the source data, used if `data_type` is `custom`, by
+         default None.
+     extension_filter : str, optional
+         Filter for file extensions, used if `data_type` is `custom`, by default "".
+     dataloader_params : dict, optional
+         Pytorch dataloader parameters, by default {}.
+     """
+
+     def __init__(
+         self,
+         pred_data: Union[str, Path, np.ndarray],
+         data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+         mean: float,
+         std: float,
+         tile_size: Optional[Tuple[int, ...]] = None,
+         tile_overlap: Optional[Tuple[int, ...]] = None,
+         axes: str = "YX",
+         batch_size: int = 1,
+         tta_transforms: bool = True,
+         transforms: Optional[Union[List, Compose]] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         dataloader_params: Optional[dict] = None,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         pred_data : Union[str, Path, np.ndarray]
+             Prediction data.
+         data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+             Data type, see `SupportedData` for available options.
+         mean : float
+             Mean value for normalization, only used if a Normalization transform
+             is defined.
+         std : float
+             Standard deviation value for normalization, only used if a
+             Normalization transform is defined.
+         tile_size : Optional[Tuple[int, ...]], optional
+             Tile size, 2D or 3D tile size, by default None.
+         tile_overlap : Optional[Tuple[int, ...]], optional
+             Tile overlap, 2D or 3D tile overlap, by default None.
+         axes : str, optional
+             Axes of the data, chosen among SCZYX, by default "YX".
+         batch_size : int, optional
+             Batch size, by default 1.
+         tta_transforms : bool, optional
+             Use test time augmentation, by default True.
+         transforms : Optional[Union[List, Compose]], optional
+             List of transforms to apply to prediction patches. If None, default
+             transforms are applied.
+         read_source_func : Optional[Callable], optional
+             Function to read the source data, used if `data_type` is `custom`, by
+             default None.
+         extension_filter : str, optional
+             Filter for file extensions, used if `data_type` is `custom`, by
+             default "".
+         dataloader_params : dict, optional
+             Pytorch dataloader parameters, by default {}.
+         """
+         if dataloader_params is None:
+             dataloader_params = {}
+         prediction_dict = {
+             "data_type": data_type,
+             "tile_size": tile_size,
+             "tile_overlap": tile_overlap,
+             "axes": axes,
+             "mean": mean,
+             "std": std,
+             "tta": tta_transforms,
+             "batch_size": batch_size,
+         }
+
+         # if transforms are passed (otherwise the default ones are used)
+         if transforms is not None:
+             prediction_dict["transforms"] = transforms
+
+         # validate configuration
+         self.prediction_config = InferenceModel(**prediction_dict)
+
+         # sanity check on the dataloader parameters
+         if "batch_size" in dataloader_params:
+             # remove it
+             del dataloader_params["batch_size"]
+
+         super().__init__(
+             prediction_config=self.prediction_config,
+             pred_data=pred_data,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+             dataloader_params=dataloader_params,
+         )
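A minimal sketch (not from the diff) of tiled prediction on an in-memory array using the datamodule above; the tile and overlap values are illustrative only.

import numpy as np

image = np.random.rand(512, 512).astype(np.float32)

predict_data = CAREamicsPredictDataModule(
    pred_data=image,
    data_type="array",
    mean=float(image.mean()),
    std=float(image.std()),
    tile_size=(256, 256),  # illustrative 2D tiling
    tile_overlap=(48, 48),
    axes="YX",
)
# then e.g.: pytorch_lightning.Trainer().predict(model, datamodule=predict_data)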