careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_module.py

@@ -0,0 +1,292 @@
+from typing import Any, Optional, Union
+
+import pytorch_lightning as L
+from torch import Tensor, nn
+
+from careamics.config import AlgorithmModel
+from careamics.config.support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedLoss,
+    SupportedOptimizer,
+    SupportedScheduler,
+)
+from careamics.losses import loss_factory
+from careamics.models.model_factory import model_factory
+from careamics.transforms import Denormalize, ImageRestorationTTA
+from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+
+class CAREamicsKiln(L.LightningModule):
+    """
+    CAREamics Lightning module.
+
+    This class encapsulates a PyTorch model along with the training, validation,
+    and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+    Attributes
+    ----------
+    model : nn.Module
+        PyTorch model.
+    loss_func : nn.Module
+        Loss function.
+    optimizer_name : str
+        Optimizer name.
+    optimizer_params : dict
+        Optimizer parameters.
+    lr_scheduler_name : str
+        Learning rate scheduler name.
+    """
+
+    def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None:
+        """
+        CAREamics Lightning module.
+
+        This class encapsulates a PyTorch model along with the training,
+        validation, and testing logic. It is configured using an
+        `AlgorithmModel` Pydantic class.
+
+        Parameters
+        ----------
+        algorithm_config : Union[AlgorithmModel, dict]
+            Algorithm configuration.
+        """
+        super().__init__()
+        # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+        if isinstance(algorithm_config, dict):
+            algorithm_config = AlgorithmModel(**algorithm_config)
+
+        # create model and loss function
+        self.model: nn.Module = model_factory(algorithm_config.model)
+        self.loss_func = loss_factory(algorithm_config.loss)
+
+        # save optimizer and lr_scheduler names and parameters
+        self.optimizer_name = algorithm_config.optimizer.name
+        self.optimizer_params = algorithm_config.optimizer.parameters
+        self.lr_scheduler_name = algorithm_config.lr_scheduler.name
+        self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters
+
+    def forward(self, x: Any) -> Any:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x : Any
+            Input tensor.
+
+        Returns
+        -------
+        Any
+            Output tensor.
+        """
+        return self.model(x)
+
+    def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+        """Training step.
+
+        Parameters
+        ----------
+        batch : Tensor
+            Input batch.
+        batch_idx : Any
+            Batch index.
+
+        Returns
+        -------
+        Any
+            Loss value.
+        """
+        x, *aux = batch
+        out = self.model(x)
+        loss = self.loss_func(out, *aux)
+        self.log(
+            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+        )
+        return loss
+
+    def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+        """Validation step.
+
+        Parameters
+        ----------
+        batch : Tensor
+            Input batch.
+        batch_idx : Any
+            Batch index.
+        """
+        x, *aux = batch
+        out = self.model(x)
+        val_loss = self.loss_func(out, *aux)
+
+        # log validation loss
+        self.log(
+            "val_loss",
+            val_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+
+    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+        """Prediction step.
+
+        Parameters
+        ----------
+        batch : Tensor
+            Input batch.
+        batch_idx : Any
+            Batch index.
+
+        Returns
+        -------
+        Any
+            Model output.
+        """
+        x, *aux = batch
+
+        # apply test-time augmentation if available
+        # TODO: probably won't work with batch size > 1
+        if self._trainer.datamodule.prediction_config.tta_transforms:
+            tta = ImageRestorationTTA()
+            augmented_batch = tta.forward(batch[0])  # list of augmented tensors
+            augmented_output = []
+            for augmented in augmented_batch:
+                augmented_pred = self.model(augmented)
+                augmented_output.append(augmented_pred)
+            output = tta.backward(augmented_output)
+        else:
+            output = self.model(x)
+
+        # Denormalize the output
+        denorm = Denormalize(
+            mean=self._trainer.datamodule.predict_dataset.mean,
+            std=self._trainer.datamodule.predict_dataset.std,
+        )
+        denormalized_output = denorm(image=output)["image"]
+
+        if len(aux) > 0:
+            return denormalized_output, aux
+        else:
+            return denormalized_output
+
+    def configure_optimizers(self) -> Any:
+        """Configure optimizers and learning rate schedulers.
+
+        Returns
+        -------
+        Any
+            Optimizer and learning rate scheduler.
+        """
+        # instantiate optimizer
+        optimizer_func = get_optimizer(self.optimizer_name)
+        optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+        # and scheduler
+        scheduler_func = get_scheduler(self.lr_scheduler_name)
+        scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+        }
+
+
+class CAREamicsModule(CAREamicsKiln):
+    """Class defining the API for the CAREamics Lightning layer.
+
+    This class exposes parameters used to create an AlgorithmModel instance,
+    triggering parameter validation.
+
+    Parameters
+    ----------
+    algorithm : Union[SupportedAlgorithm, str]
+        Algorithm to use for training (see SupportedAlgorithm).
+    loss : Union[SupportedLoss, str]
+        Loss function to use for training (see SupportedLoss).
+    architecture : Union[SupportedArchitecture, str]
+        Model architecture to use for training (see SupportedArchitecture).
+    model_parameters : dict, optional
+        Model parameters to use for training, by default {}. Model parameters are
+        defined in the relevant `torch.nn.Module` class, or Pydantic model (see
+        `careamics.config.architectures`).
+    optimizer : Union[SupportedOptimizer, str], optional
+        Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
+    optimizer_parameters : dict, optional
+        Optimizer parameters to use for training, as defined in `torch.optim`, by
+        default {}.
+    lr_scheduler : Union[SupportedScheduler, str], optional
+        Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
+        (see SupportedScheduler).
+    lr_scheduler_parameters : dict, optional
+        Learning rate scheduler parameters to use for training, as defined in
+        `torch.optim`, by default {}.
+    """
+
+    def __init__(
+        self,
+        algorithm: Union[SupportedAlgorithm, str],
+        loss: Union[SupportedLoss, str],
+        architecture: Union[SupportedArchitecture, str],
+        model_parameters: Optional[dict] = None,
+        optimizer: Union[SupportedOptimizer, str] = "Adam",
+        optimizer_parameters: Optional[dict] = None,
+        lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+        lr_scheduler_parameters: Optional[dict] = None,
+    ) -> None:
+        """
+        Wrapper for the CAREamics model, exposing all algorithm configuration arguments.
+
+        Parameters
+        ----------
+        algorithm : Union[SupportedAlgorithm, str]
+            Algorithm to use for training (see SupportedAlgorithm).
+        loss : Union[SupportedLoss, str]
+            Loss function to use for training (see SupportedLoss).
+        architecture : Union[SupportedArchitecture, str]
+            Model architecture to use for training (see SupportedArchitecture).
+        model_parameters : dict, optional
+            Model parameters to use for training, by default {}. Model parameters
+            are defined in the relevant `torch.nn.Module` class, or Pydantic model
+            (see `careamics.config.architectures`).
+        optimizer : Union[SupportedOptimizer, str], optional
+            Optimizer to use for training, by default "Adam" (see
+            SupportedOptimizer).
+        optimizer_parameters : dict, optional
+            Optimizer parameters to use for training, as defined in `torch.optim`,
+            by default {}.
+        lr_scheduler : Union[SupportedScheduler, str], optional
+            Learning rate scheduler to use for training, by default
+            "ReduceLROnPlateau" (see SupportedScheduler).
+        lr_scheduler_parameters : dict, optional
+            Learning rate scheduler parameters to use for training, as defined in
+            `torch.optim`, by default {}.
+        """
+        # create an AlgorithmModel-compatible dictionary
+        if lr_scheduler_parameters is None:
+            lr_scheduler_parameters = {}
+        if optimizer_parameters is None:
+            optimizer_parameters = {}
+        if model_parameters is None:
+            model_parameters = {}
+        algorithm_configuration = {
+            "algorithm": algorithm,
+            "loss": loss,
+            "optimizer": {
+                "name": optimizer,
+                "parameters": optimizer_parameters,
+            },
+            "lr_scheduler": {
+                "name": lr_scheduler,
+                "parameters": lr_scheduler_parameters,
+            },
+        }
+        model_configuration = {"architecture": architecture}
+        model_configuration.update(model_parameters)
+
+        # add model parameters to algorithm configuration
+        algorithm_configuration["model"] = model_configuration
+
+        # call the parent init using an AlgorithmModel instance
+        super().__init__(AlgorithmModel(**algorithm_configuration))
+
+        # TODO add load_from_checkpoint wrapper
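For orientation, here is a minimal usage sketch of the `CAREamicsModule` API added in the hunk above. It is illustrative only: the `"n2v"` and `"UNet"` strings are assumed values for the `SupportedAlgorithm`/`SupportedLoss`/`SupportedArchitecture` enums, not something this diff confirms.

```python
# Hypothetical usage sketch, not part of this diff.
from careamics.lightning_module import CAREamicsModule

module = CAREamicsModule(
    algorithm="n2v",                    # assumed SupportedAlgorithm value
    loss="n2v",                         # assumed SupportedLoss value
    architecture="UNet",                # assumed SupportedArchitecture value
    optimizer="Adam",                   # default shown in the signature above
    optimizer_parameters={"lr": 1e-4},  # forwarded to torch.optim.Adam
    lr_scheduler="ReduceLROnPlateau",   # default shown in the signature above
)

# The keyword arguments are assembled into an AlgorithmModel dictionary and
# validated by Pydantic in CAREamicsKiln.__init__, so invalid names fail early.
```

A `pytorch_lightning.Trainer` would then drive `training_step`/`validation_step` via `trainer.fit(module, datamodule=...)`.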
careamics/lightning_prediction_datamodule.py

@@ -0,0 +1,390 @@
+from pathlib import Path
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as L
+from albumentations import Compose
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+
+from careamics.config import InferenceModel
+from careamics.config.support import SupportedData
+from careamics.config.tile_information import TileInformation
+from careamics.dataset.dataset_utils import (
+    get_read_func,
+    list_files,
+)
+from careamics.dataset.in_memory_dataset import (
+    InMemoryPredictionDataset,
+)
+from careamics.dataset.iterable_dataset import (
+    IterablePredictionDataset,
+)
+from careamics.utils import get_logger
+
+PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset]
+
+logger = get_logger(__name__)
+
+
+def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
+    """
+    Collate tiles received from the CAREamics prediction dataloader.
+
+    The CAREamics prediction dataloader returns tuples of arrays and
+    TileInformation. For non-tiled data, this function returns the arrays alone.
+    For tiled data, it returns the arrays together with the last-tile flag, the
+    overlap crop coordinates and the stitch coordinates.
+
+    Parameters
+    ----------
+    batch : List[Tuple[np.ndarray, TileInformation]]
+        Batch of tiles.
+
+    Returns
+    -------
+    Any
+        Collated batch.
+    """
+    first_tile_info: TileInformation = batch[0][1]
+    # if not tiled, then return arrays
+    if not first_tile_info.tiled:
+        arrays, _ = zip(*batch)
+
+        return default_collate(arrays)
+    # else, make the last_tile flag and the coordinates explicit
+    else:
+        new_batch = [
+            (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
+            for tile, t in batch
+        ]
+
+        return default_collate(new_batch)
+
+
+class CAREamicsClay(L.LightningDataModule):
+    """
+    LightningDataModule for the prediction dataset.
+
+    The data module can be used with Path, str or numpy arrays. The data can be
+    either a folder containing images or a single file.
+
+    To read custom data types, you can set `data_type` to `custom` in `data_config`
+    and provide a function that returns a numpy array from a path as the
+    `read_source_func` parameter. The function will receive a Path object and
+    an axes string as arguments, the axes being derived from the `data_config`.
+
+    You can also provide a `fnmatch`- and `Path.rglob`-compatible expression (e.g.
+    "*.czi") to filter the file extensions using `extension_filter`.
+
+    Parameters
+    ----------
+    prediction_config : InferenceModel
+        Pydantic model for the CAREamics prediction configuration.
+    pred_data : Union[Path, str, np.ndarray]
+        Prediction data; can be a path to a folder, a file or a numpy array.
+    read_source_func : Optional[Callable], optional
+        Function to read custom types, by default None.
+    extension_filter : str, optional
+        Expression used to filter file extensions for custom types, by default "".
+    dataloader_params : dict, optional
+        Dataloader parameters, by default {}.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceModel,
+        pred_data: Union[Path, str, np.ndarray],
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        dataloader_params: Optional[dict] = None,
+    ) -> None:
+        """
+        Constructor.
+
+        The data module can be used with Path, str or numpy arrays. The data can
+        be either a folder containing images or a single file.
+
+        To read custom data types, you can set `data_type` to `custom` in
+        `data_config` and provide a function that returns a numpy array from a
+        path as the `read_source_func` parameter. The function will receive a
+        Path object and an axes string as arguments, the axes being derived from
+        the `data_config`.
+
+        You can also provide a `fnmatch`- and `Path.rglob`-compatible expression
+        (e.g. "*.czi") to filter the file extensions using `extension_filter`.
+
+        Parameters
+        ----------
+        prediction_config : InferenceModel
+            Pydantic model for the CAREamics prediction configuration.
+        pred_data : Union[Path, str, np.ndarray]
+            Prediction data; can be a path to a folder, a file or a numpy array.
+        read_source_func : Optional[Callable], optional
+            Function to read custom types, by default None.
+        extension_filter : str, optional
+            Expression used to filter file extensions for custom types, by
+            default "".
+        dataloader_params : dict, optional
+            Dataloader parameters, by default {}.
+
+        Raises
+        ------
+        ValueError
+            If the data type is `custom` and no `read_source_func` is provided.
+        ValueError
+            If the data type is `array` and the input is not a numpy array.
+        ValueError
+            If the data type is `tiff` and the input is neither a Path nor a str.
+        """
+        if dataloader_params is None:
+            dataloader_params = {}
+        super().__init__()
+
+        # check that a read source function is provided for custom types
+        if (
+            prediction_config.data_type == SupportedData.CUSTOM
+            and read_source_func is None
+        ):
+            raise ValueError(
+                f"Data type {SupportedData.CUSTOM} is not allowed without "
+                f"specifying a `read_source_func`."
+            )
+
+        # and that arrays are passed, if array type specified
+        elif prediction_config.data_type == SupportedData.ARRAY and not isinstance(
+            pred_data, np.ndarray
+        ):
+            raise ValueError(
+                f"Expected array input (see configuration.data.data_type), but got "
+                f"{type(pred_data)} instead."
+            )
+
+        # and that Path or str are passed, if tiff file type specified
+        elif prediction_config.data_type == SupportedData.TIFF and not (
+            isinstance(pred_data, Path) or isinstance(pred_data, str)
+        ):
+            raise ValueError(
+                f"Expected Path or str input (see configuration.data.data_type), "
+                f"but got {type(pred_data)} instead."
+            )
+
+        # configuration data
+        self.prediction_config = prediction_config
+        self.data_type = prediction_config.data_type
+        self.batch_size = prediction_config.batch_size
+        self.dataloader_params = dataloader_params
+
+        self.pred_data = pred_data
+        self.tile_size = prediction_config.tile_size
+        self.tile_overlap = prediction_config.tile_overlap
+
+        # read source function
+        if prediction_config.data_type == SupportedData.CUSTOM:
+            # mypy check
+            assert read_source_func is not None
+
+            self.read_source_func: Callable = read_source_func
+        elif prediction_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(prediction_config.data_type)
+
+        self.extension_filter = extension_filter
+
+    def prepare_data(self) -> None:
+        """Hook used to prepare the data before calling `setup`."""
+        # if the data is a Path or a str
+        if not isinstance(self.pred_data, np.ndarray):
+            self.pred_files = list_files(
+                self.pred_data, self.data_type, self.extension_filter
+            )
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """
+        Hook called at the beginning of predict.
+
+        Parameters
+        ----------
+        stage : Optional[str], optional
+            Stage, by default None.
+        """
+        # if numpy array
+        if self.data_type == SupportedData.ARRAY:
+            # prediction dataset
+            self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset(
+                prediction_config=self.prediction_config,
+                inputs=self.pred_data,
+            )
+        else:
+            self.predict_dataset = IterablePredictionDataset(
+                prediction_config=self.prediction_config,
+                src_files=self.pred_files,
+                read_source_func=self.read_source_func,
+            )
+
+    def predict_dataloader(self) -> DataLoader:
+        """
+        Create a dataloader for prediction.
+
+        Returns
+        -------
+        DataLoader
+            Prediction dataloader.
+        """
+        return DataLoader(
+            self.predict_dataset,
+            batch_size=self.batch_size,
+            collate_fn=_collate_tiles,
+            **self.dataloader_params,
+        )  # TODO check workers are used
+
+
+class CAREamicsPredictDataModule(CAREamicsClay):
+    """
+    LightningDataModule wrapper of an inference dataset.
+
+    Since the Lightning datamodule has no access to the model, make sure that the
+    parameters passed to the datamodule are consistent with the model's
+    requirements.
+
+    The data module can be used with Path, str or numpy arrays. To use array data,
+    set `data_type` to `array` and pass a numpy array to `pred_data`.
+
+    The default transformations applied to the images are defined in
+    `careamics.config.inference_model`. To use different transformations, pass a
+    list of transforms or an albumentations `Compose` as the `transforms`
+    parameter. See the examples for more details.
+
+    The `mean` and `std` parameters are only used if Normalization is defined
+    either in the default transformations or in the `transforms` parameter, but
+    not with a `Compose` object. If you pass a `Normalization` transform in a list
+    as `transforms`, then the mean and std parameters will be overwritten by those
+    passed to this method.
+
+    By default, CAREamics only supports types defined in
+    `careamics.config.support.SupportedData`. To read custom data types, you can
+    set `data_type` to `custom` and provide a function that returns a numpy array
+    from a path. Additionally, pass a `fnmatch`- and `Path.rglob`-compatible
+    expression (e.g. "*.jpeg") to filter the file extensions using
+    `extension_filter`.
+
+    In `dataloader_params`, you can pass any parameter accepted by PyTorch
+    dataloaders, except for `batch_size`, which is set by the `batch_size`
+    parameter.
+
+    Parameters
+    ----------
+    pred_data : Union[str, Path, np.ndarray]
+        Prediction data.
+    data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+        Data type, see `SupportedData` for available options.
+    mean : float
+        Mean value for normalization, only used if Normalization is defined in the
+        transforms.
+    std : float
+        Standard deviation value for normalization, only used if Normalization is
+        defined in the transforms.
+    tile_size : Optional[Tuple[int, ...]], optional
+        2D or 3D tile size, by default None.
+    tile_overlap : Optional[Tuple[int, ...]], optional
+        2D or 3D tile overlap, by default None.
+    axes : str
+        Axes of the data, chosen among SCZYX.
+    batch_size : int
+        Batch size.
+    tta_transforms : bool, optional
+        Use test-time augmentation, by default True.
+    transforms : Optional[Union[List, Compose]], optional
+        List of transforms to apply to prediction patches. If None, default
+        transforms are applied.
+    read_source_func : Optional[Callable], optional
+        Function to read the source data, used if `data_type` is `custom`, by
+        default None.
+    extension_filter : str, optional
+        Filter for file extensions, used if `data_type` is `custom`, by default "".
+    dataloader_params : dict, optional
+        PyTorch dataloader parameters, by default {}.
+    """
+
+    def __init__(
+        self,
+        pred_data: Union[str, Path, np.ndarray],
+        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+        mean: float,
+        std: float,
+        tile_size: Optional[Tuple[int, ...]] = None,
+        tile_overlap: Optional[Tuple[int, ...]] = None,
+        axes: str = "YX",
+        batch_size: int = 1,
+        tta_transforms: bool = True,
+        transforms: Optional[Union[List, Compose]] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        dataloader_params: Optional[dict] = None,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        pred_data : Union[str, Path, np.ndarray]
+            Prediction data.
+        data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
+            Data type, see `SupportedData` for available options.
+        tile_size : Optional[Tuple[int, ...]], optional
+            2D or 3D tile size, by default None.
+        tile_overlap : Optional[Tuple[int, ...]], optional
+            2D or 3D tile overlap, by default None.
+        axes : str
+            Axes of the data, chosen among SCZYX.
+        batch_size : int
+            Batch size.
+        tta_transforms : bool, optional
+            Use test-time augmentation, by default True.
+        mean : float
+            Mean value for normalization, only used if Normalization is defined
+            in the transforms.
+        std : float
+            Standard deviation value for normalization, only used if
+            Normalization is defined in the transforms.
+        transforms : Optional[Union[List, Compose]], optional
+            List of transforms to apply to prediction patches. If None, default
+            transforms are applied.
+        read_source_func : Optional[Callable], optional
+            Function to read the source data, used if `data_type` is `custom`,
+            by default None.
+        extension_filter : str, optional
+            Filter for file extensions, used if `data_type` is `custom`, by
+            default "".
+        dataloader_params : dict, optional
+            PyTorch dataloader parameters, by default {}.
+        """
+        if dataloader_params is None:
+            dataloader_params = {}
+        prediction_dict = {
+            "data_type": data_type,
+            "tile_size": tile_size,
+            "tile_overlap": tile_overlap,
+            "axes": axes,
+            "mean": mean,
+            "std": std,
+            "tta_transforms": tta_transforms,
+            "batch_size": batch_size,
+        }
+
+        # if transforms are passed (otherwise the default ones will be used)
+        if transforms is not None:
+            prediction_dict["transforms"] = transforms
+
+        # validate configuration
+        self.prediction_config = InferenceModel(**prediction_dict)
+
+        # sanity check on the dataloader parameters
+        if "batch_size" in dataloader_params:
+            # remove it
+            del dataloader_params["batch_size"]
+
+        super().__init__(
+            prediction_config=self.prediction_config,
+            pred_data=pred_data,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
+        )
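A corresponding sketch for the prediction datamodule above, again hedged: the tile sizes and statistics below are placeholder values, and `mean`/`std` would normally come from the statistics computed at training time.

```python
# Hypothetical usage sketch, not part of this diff.
import numpy as np

from careamics.lightning_prediction_datamodule import CAREamicsPredictDataModule

image = np.random.rand(512, 512).astype(np.float32)  # dummy YX image

predict_data = CAREamicsPredictDataModule(
    pred_data=image,
    data_type="array",         # in-memory array -> InMemoryPredictionDataset
    mean=float(image.mean()),  # placeholder; use training-time statistics
    std=float(image.std()),    # placeholder; use training-time statistics
    tile_size=(256, 256),      # optional tiling, 2D here
    tile_overlap=(48, 48),
    axes="YX",
    batch_size=1,
)

# trainer.predict(module, datamodule=predict_data) would then run inference,
# with _collate_tiles exposing the tile coordinates needed for stitching.
```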