careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
95
95
|
Batch size.
|
|
96
96
|
use_in_memory : bool
|
|
97
97
|
Whether to use in memory dataset if possible.
|
|
98
|
-
train_data : Union[Path,
|
|
98
|
+
train_data : Union[Path, np.ndarray]
|
|
99
99
|
Training data.
|
|
100
|
-
val_data : Optional[Union[Path,
|
|
100
|
+
val_data : Optional[Union[Path, np.ndarray]]
|
|
101
101
|
Validation data.
|
|
102
|
-
train_data_target : Optional[Union[Path,
|
|
102
|
+
train_data_target : Optional[Union[Path, np.ndarray]]
|
|
103
103
|
Training target data.
|
|
104
|
-
val_data_target : Optional[Union[Path,
|
|
104
|
+
val_data_target : Optional[Union[Path, np.ndarray]]
|
|
105
105
|
Validation target data.
|
|
106
106
|
val_percentage : float
|
|
107
107
|
Percentage of the training data to use for validation, if no validation data is
|
|
@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
217
217
|
)
|
|
218
218
|
|
|
219
219
|
# configuration
|
|
220
|
-
self.data_config = data_config
|
|
221
|
-
self.data_type = data_config.data_type
|
|
222
|
-
self.batch_size = data_config.batch_size
|
|
223
|
-
self.use_in_memory = use_in_memory
|
|
220
|
+
self.data_config: DataConfig = data_config
|
|
221
|
+
self.data_type: str = data_config.data_type
|
|
222
|
+
self.batch_size: int = data_config.batch_size
|
|
223
|
+
self.use_in_memory: bool = use_in_memory
|
|
224
|
+
|
|
225
|
+
# data: make data Path or np.ndarray, use type annotations for mypy
|
|
226
|
+
self.train_data: Union[Path, np.ndarray] = (
|
|
227
|
+
Path(train_data) if isinstance(train_data, str) else train_data
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
self.val_data: Union[Path, np.ndarray] = (
|
|
231
|
+
Path(val_data) if isinstance(val_data, str) else val_data
|
|
232
|
+
)
|
|
224
233
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
234
|
+
self.train_data_target: Union[Path, np.ndarray] = (
|
|
235
|
+
Path(train_data_target)
|
|
236
|
+
if isinstance(train_data_target, str)
|
|
237
|
+
else train_data_target
|
|
238
|
+
)
|
|
228
239
|
|
|
229
|
-
self.
|
|
230
|
-
|
|
240
|
+
self.val_data_target: Union[Path, np.ndarray] = (
|
|
241
|
+
Path(val_data_target)
|
|
242
|
+
if isinstance(val_data_target, str)
|
|
243
|
+
else val_data_target
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
# validation split
|
|
231
247
|
self.val_percentage = val_percentage
|
|
232
248
|
self.val_minimum_split = val_minimum_split
|
|
233
249
|
|
|
@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
241
257
|
elif data_config.data_type != SupportedData.ARRAY:
|
|
242
258
|
self.read_source_func = get_read_func(data_config.data_type)
|
|
243
259
|
|
|
244
|
-
self.extension_filter = extension_filter
|
|
260
|
+
self.extension_filter: str = extension_filter
|
|
245
261
|
|
|
246
262
|
# Pytorch dataloader parameters
|
|
247
|
-
self.dataloader_params = (
|
|
263
|
+
self.dataloader_params: Dict[str, Any] = (
|
|
248
264
|
data_config.dataloader_params if data_config.dataloader_params else {}
|
|
249
265
|
)
|
|
250
266
|
|
|
@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
309
325
|
"""
|
|
310
326
|
# if numpy array
|
|
311
327
|
if self.data_type == SupportedData.ARRAY:
|
|
328
|
+
# mypy checks
|
|
329
|
+
assert isinstance(self.train_data, np.ndarray)
|
|
330
|
+
if self.train_data_target is not None:
|
|
331
|
+
assert isinstance(self.train_data_target, np.ndarray)
|
|
332
|
+
|
|
312
333
|
# train dataset
|
|
313
334
|
self.train_dataset: DatasetType = InMemoryDataset(
|
|
314
335
|
data_config=self.data_config,
|
|
315
336
|
inputs=self.train_data,
|
|
316
|
-
|
|
337
|
+
input_target=self.train_data_target,
|
|
317
338
|
)
|
|
318
339
|
|
|
319
340
|
# validation dataset
|
|
320
341
|
if self.val_data is not None:
|
|
342
|
+
# mypy checks
|
|
343
|
+
assert isinstance(self.val_data, np.ndarray)
|
|
344
|
+
if self.val_data_target is not None:
|
|
345
|
+
assert isinstance(self.val_data_target, np.ndarray)
|
|
346
|
+
|
|
321
347
|
# create its own dataset
|
|
322
348
|
self.val_dataset: DatasetType = InMemoryDataset(
|
|
323
349
|
data_config=self.data_config,
|
|
324
350
|
inputs=self.val_data,
|
|
325
|
-
|
|
351
|
+
input_target=self.val_data_target,
|
|
326
352
|
)
|
|
327
353
|
else:
|
|
328
354
|
# extract validation from the training patches
|
|
@@ -341,7 +367,7 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
341
367
|
self.train_dataset = InMemoryDataset(
|
|
342
368
|
data_config=self.data_config,
|
|
343
369
|
inputs=self.train_files,
|
|
344
|
-
|
|
370
|
+
input_target=(
|
|
345
371
|
self.train_target_files if self.train_data_target else None
|
|
346
372
|
),
|
|
347
373
|
read_source_func=self.read_source_func,
|
|
@@ -352,7 +378,7 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
352
378
|
self.val_dataset = InMemoryDataset(
|
|
353
379
|
data_config=self.data_config,
|
|
354
380
|
inputs=self.val_files,
|
|
355
|
-
|
|
381
|
+
input_target=(
|
|
356
382
|
self.val_target_files if self.val_data_target else None
|
|
357
383
|
),
|
|
358
384
|
read_source_func=self.read_source_func,
|
|
@@ -557,12 +583,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
557
583
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
558
584
|
>>> my_transforms = [
|
|
559
585
|
... {
|
|
560
|
-
... "name": SupportedTransform.
|
|
561
|
-
... "mean": 0,
|
|
562
|
-
... "std": 1,
|
|
563
|
-
... },
|
|
564
|
-
... {
|
|
565
|
-
... "name": SupportedTransform.N2V_MANIPULATE.value,
|
|
586
|
+
... "name": SupportedTransform.XY_FLIP.value,
|
|
566
587
|
... }
|
|
567
588
|
... ]
|
|
568
589
|
>>> data_module = TrainingDataWrapper(
|
careamics/lightning_module.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""CAREamics Lightning module."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Optional, Union
|
|
2
4
|
|
|
3
5
|
import pytorch_lightning as L
|
|
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
|
|
|
24
26
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
25
27
|
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
26
28
|
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
algorithm_config : Union[AlgorithmModel, dict]
|
|
32
|
+
Algorithm configuration.
|
|
33
|
+
|
|
27
34
|
Attributes
|
|
28
35
|
----------
|
|
29
36
|
model : nn.Module
|
|
@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
39
46
|
"""
|
|
40
47
|
|
|
41
48
|
def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
|
|
42
|
-
"""
|
|
43
|
-
CAREamics Lightning module.
|
|
49
|
+
"""Lightning module for CAREamics.
|
|
44
50
|
|
|
45
51
|
This class encapsulates the a PyTorch model along with the training, validation,
|
|
46
52
|
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
@@ -142,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
|
|
|
142
148
|
Any
|
|
143
149
|
Model output.
|
|
144
150
|
"""
|
|
145
|
-
|
|
151
|
+
if self._trainer.datamodule.tiled:
|
|
152
|
+
x, *aux = batch
|
|
153
|
+
else:
|
|
154
|
+
x = batch
|
|
155
|
+
aux = []
|
|
146
156
|
|
|
147
157
|
# apply test-time augmentation if available
|
|
148
158
|
# TODO: probably wont work with batch size > 1
|
|
149
159
|
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
150
160
|
tta = ImageRestorationTTA()
|
|
151
|
-
augmented_batch = tta.forward(
|
|
161
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
152
162
|
augmented_output = []
|
|
153
163
|
for augmented in augmented_batch:
|
|
154
164
|
augmented_pred = self.model(augmented)
|
|
@@ -159,13 +169,13 @@ class CAREamicsModule(L.LightningModule):
|
|
|
159
169
|
|
|
160
170
|
# Denormalize the output
|
|
161
171
|
denorm = Denormalize(
|
|
162
|
-
|
|
163
|
-
|
|
172
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
173
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
164
174
|
)
|
|
165
|
-
denormalized_output
|
|
175
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
166
176
|
|
|
167
|
-
if len(aux) > 0:
|
|
168
|
-
return denormalized_output, aux
|
|
177
|
+
if len(aux) > 0: # aux can be tiling information
|
|
178
|
+
return denormalized_output, *aux
|
|
169
179
|
else:
|
|
170
180
|
return denormalized_output
|
|
171
181
|
|
|
@@ -1,68 +1,37 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torch.utils.data.dataloader import default_collate
|
|
10
9
|
|
|
11
10
|
from careamics.config import InferenceConfig
|
|
12
11
|
from careamics.config.support import SupportedData
|
|
13
|
-
from careamics.
|
|
12
|
+
from careamics.dataset import (
|
|
13
|
+
InMemoryPredDataset,
|
|
14
|
+
InMemoryTiledPredDataset,
|
|
15
|
+
IterablePredDataset,
|
|
16
|
+
IterableTiledPredDataset,
|
|
17
|
+
)
|
|
14
18
|
from careamics.dataset.dataset_utils import (
|
|
15
19
|
get_read_func,
|
|
16
20
|
list_files,
|
|
17
21
|
)
|
|
18
|
-
from careamics.dataset.
|
|
19
|
-
InMemoryPredictionDataset,
|
|
20
|
-
)
|
|
21
|
-
from careamics.dataset.iterable_dataset import (
|
|
22
|
-
IterablePredictionDataset,
|
|
23
|
-
)
|
|
22
|
+
from careamics.dataset.tiling.collate_tiles import collate_tiles
|
|
24
23
|
from careamics.utils import get_logger
|
|
25
24
|
|
|
26
|
-
PredictDatasetType = Union[
|
|
25
|
+
PredictDatasetType = Union[
|
|
26
|
+
InMemoryPredDataset,
|
|
27
|
+
InMemoryTiledPredDataset,
|
|
28
|
+
IterablePredDataset,
|
|
29
|
+
IterableTiledPredDataset,
|
|
30
|
+
]
|
|
27
31
|
|
|
28
32
|
logger = get_logger(__name__)
|
|
29
33
|
|
|
30
34
|
|
|
31
|
-
def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
32
|
-
"""
|
|
33
|
-
Collate tiles received from CAREamics prediction dataloader.
|
|
34
|
-
|
|
35
|
-
CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
|
|
36
|
-
case of non-tiled data, this function will return the arrays. In case of tiled data,
|
|
37
|
-
it will return the arrays, the last tile flag, the overlap crop coordinates and the
|
|
38
|
-
stitch coordinates.
|
|
39
|
-
|
|
40
|
-
Parameters
|
|
41
|
-
----------
|
|
42
|
-
batch : List[Tuple[np.ndarray, TileInformation], ...]
|
|
43
|
-
Batch of tiles.
|
|
44
|
-
|
|
45
|
-
Returns
|
|
46
|
-
-------
|
|
47
|
-
Any
|
|
48
|
-
Collated batch.
|
|
49
|
-
"""
|
|
50
|
-
first_tile_info: TileInformation = batch[0][1]
|
|
51
|
-
# if not tiled, then return arrays
|
|
52
|
-
if not first_tile_info.tiled:
|
|
53
|
-
arrays, _ = zip(*batch)
|
|
54
|
-
|
|
55
|
-
return default_collate(arrays)
|
|
56
|
-
# else we explicit the last_tile flag and coordinates
|
|
57
|
-
else:
|
|
58
|
-
new_batch = [
|
|
59
|
-
(tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
|
|
60
|
-
for tile, t in batch
|
|
61
|
-
]
|
|
62
|
-
|
|
63
|
-
return default_collate(new_batch)
|
|
64
|
-
|
|
65
|
-
|
|
66
35
|
class CAREamicsPredictData(L.LightningDataModule):
|
|
67
36
|
"""
|
|
68
37
|
CAREamics Lightning prediction data module.
|
|
@@ -182,6 +151,9 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
182
151
|
self.tile_size = pred_config.tile_size
|
|
183
152
|
self.tile_overlap = pred_config.tile_overlap
|
|
184
153
|
|
|
154
|
+
# check if it is tiled
|
|
155
|
+
self.tiled = self.tile_size is not None and self.tile_overlap is not None
|
|
156
|
+
|
|
185
157
|
# read source function
|
|
186
158
|
if pred_config.data_type == SupportedData.CUSTOM:
|
|
187
159
|
# mypy check
|
|
@@ -212,17 +184,29 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
212
184
|
"""
|
|
213
185
|
# if numpy array
|
|
214
186
|
if self.data_type == SupportedData.ARRAY:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
187
|
+
if self.tiled:
|
|
188
|
+
self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
|
|
189
|
+
prediction_config=self.prediction_config,
|
|
190
|
+
inputs=self.pred_data,
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
self.predict_dataset = InMemoryPredDataset(
|
|
194
|
+
prediction_config=self.prediction_config,
|
|
195
|
+
inputs=self.pred_data,
|
|
196
|
+
)
|
|
220
197
|
else:
|
|
221
|
-
self.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
198
|
+
if self.tiled:
|
|
199
|
+
self.predict_dataset = IterableTiledPredDataset(
|
|
200
|
+
prediction_config=self.prediction_config,
|
|
201
|
+
src_files=self.pred_files,
|
|
202
|
+
read_source_func=self.read_source_func,
|
|
203
|
+
)
|
|
204
|
+
else:
|
|
205
|
+
self.predict_dataset = IterablePredDataset(
|
|
206
|
+
prediction_config=self.prediction_config,
|
|
207
|
+
src_files=self.pred_files,
|
|
208
|
+
read_source_func=self.read_source_func,
|
|
209
|
+
)
|
|
226
210
|
|
|
227
211
|
def predict_dataloader(self) -> DataLoader:
|
|
228
212
|
"""
|
|
@@ -236,7 +220,7 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
236
220
|
return DataLoader(
|
|
237
221
|
self.predict_dataset,
|
|
238
222
|
batch_size=self.batch_size,
|
|
239
|
-
collate_fn=
|
|
223
|
+
collate_fn=collate_tiles if self.tiled else None,
|
|
240
224
|
**self.dataloader_params,
|
|
241
225
|
) # TODO check workers are used
|
|
242
226
|
|
|
@@ -287,12 +271,10 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
287
271
|
Prediction data.
|
|
288
272
|
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
|
|
289
273
|
Data type, see `SupportedData` for available options.
|
|
290
|
-
|
|
291
|
-
Mean
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
Standard deviation value for normalization, only used if Normalization is
|
|
295
|
-
defined in the transform.
|
|
274
|
+
image_means : list of float
|
|
275
|
+
Mean values for normalization, only used if Normalization is defined.
|
|
276
|
+
image_stds : list of float
|
|
277
|
+
Std values for normalization, only used if Normalization is defined.
|
|
296
278
|
tile_size : Tuple[int, ...]
|
|
297
279
|
Tile size, 2D or 3D tile size.
|
|
298
280
|
tile_overlap : Tuple[int, ...]
|
|
@@ -303,9 +285,6 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
303
285
|
Batch size.
|
|
304
286
|
tta_transforms : bool, optional
|
|
305
287
|
Use test time augmentation, by default True.
|
|
306
|
-
transforms : List, optional
|
|
307
|
-
List of transforms to apply to prediction patches. If None, default
|
|
308
|
-
transforms are applied.
|
|
309
288
|
read_source_func : Optional[Callable], optional
|
|
310
289
|
Function to read the source data, used if `data_type` is `custom`, by
|
|
311
290
|
default None.
|
|
@@ -319,14 +298,13 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
319
298
|
self,
|
|
320
299
|
pred_data: Union[str, Path, np.ndarray],
|
|
321
300
|
data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
|
|
322
|
-
|
|
323
|
-
|
|
301
|
+
image_means=list[float],
|
|
302
|
+
image_stds=list[float],
|
|
324
303
|
tile_size: Optional[Tuple[int, ...]] = None,
|
|
325
304
|
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
326
305
|
axes: str = "YX",
|
|
327
306
|
batch_size: int = 1,
|
|
328
307
|
tta_transforms: bool = True,
|
|
329
|
-
transforms: Optional[List] = None,
|
|
330
308
|
read_source_func: Optional[Callable] = None,
|
|
331
309
|
extension_filter: str = "",
|
|
332
310
|
dataloader_params: Optional[dict] = None,
|
|
@@ -340,12 +318,10 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
340
318
|
Prediction data.
|
|
341
319
|
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
|
|
342
320
|
Data type, see `SupportedData` for available options.
|
|
343
|
-
|
|
344
|
-
Mean
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
Standard deviation value for normalization, only used if Normalization is
|
|
348
|
-
defined in the transform.
|
|
321
|
+
image_means : list of float
|
|
322
|
+
Mean values for normalization, only used if Normalization is defined.
|
|
323
|
+
image_stds : list of float
|
|
324
|
+
Std values for normalization, only used if Normalization is defined.
|
|
349
325
|
tile_size : List[int]
|
|
350
326
|
Tile size, 2D or 3D tile size.
|
|
351
327
|
tile_overlap : List[int]
|
|
@@ -356,9 +332,6 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
356
332
|
Batch size.
|
|
357
333
|
tta_transforms : bool, optional
|
|
358
334
|
Use test time augmentation, by default True.
|
|
359
|
-
transforms : Optional[List], optional
|
|
360
|
-
List of transforms to apply to prediction patches. If None, default
|
|
361
|
-
transforms are applied.
|
|
362
335
|
read_source_func : Optional[Callable], optional
|
|
363
336
|
Function to read the source data, used if `data_type` is `custom`, by
|
|
364
337
|
default None.
|
|
@@ -369,21 +342,18 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
369
342
|
"""
|
|
370
343
|
if dataloader_params is None:
|
|
371
344
|
dataloader_params = {}
|
|
372
|
-
prediction_dict = {
|
|
345
|
+
prediction_dict: Dict[str, Any] = {
|
|
373
346
|
"data_type": data_type,
|
|
374
347
|
"tile_size": tile_size,
|
|
375
348
|
"tile_overlap": tile_overlap,
|
|
376
349
|
"axes": axes,
|
|
377
|
-
"
|
|
378
|
-
"
|
|
350
|
+
"image_means": image_means,
|
|
351
|
+
"image_stds": image_stds,
|
|
379
352
|
"tta": tta_transforms,
|
|
380
353
|
"batch_size": batch_size,
|
|
354
|
+
"transforms": [],
|
|
381
355
|
}
|
|
382
356
|
|
|
383
|
-
# if transforms are passed (otherwise it will use the default ones)
|
|
384
|
-
if transforms is not None:
|
|
385
|
-
prediction_dict["transforms"] = transforms
|
|
386
|
-
|
|
387
357
|
# validate configuration
|
|
388
358
|
self.prediction_config = InferenceConfig(**prediction_dict)
|
|
389
359
|
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
__all__ = ["loss_factory"]
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
|
|
5
|
+
from .loss_factory import loss_factory
|
careamics/losses/loss_factory.py
CHANGED
careamics/losses/losses.py
CHANGED
|
@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
-
|
|
9
|
-
# TODO if we are only using the DiceLoss, can we just implement it?
|
|
10
|
-
# from segmentation_models_pytorch.losses import DiceLoss
|
|
11
8
|
from torch.nn import L1Loss, MSELoss
|
|
12
9
|
|
|
13
10
|
|
|
14
|
-
def mse_loss(
|
|
11
|
+
def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
15
12
|
"""
|
|
16
13
|
Mean squared error loss.
|
|
17
14
|
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
source : torch.Tensor
|
|
18
|
+
Source patches.
|
|
19
|
+
target : torch.Tensor
|
|
20
|
+
Target patches.
|
|
21
|
+
|
|
18
22
|
Returns
|
|
19
23
|
-------
|
|
20
24
|
torch.Tensor
|
|
21
25
|
Loss value.
|
|
22
26
|
"""
|
|
23
27
|
loss = MSELoss()
|
|
24
|
-
return loss(
|
|
28
|
+
return loss(source, target)
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
def n2v_loss(
|
|
@@ -34,9 +38,9 @@ def n2v_loss(
|
|
|
34
38
|
|
|
35
39
|
Parameters
|
|
36
40
|
----------
|
|
37
|
-
|
|
41
|
+
manipulated_patches : torch.Tensor
|
|
38
42
|
Patches with manipulated pixels.
|
|
39
|
-
|
|
43
|
+
original_patches : torch.Tensor
|
|
40
44
|
Noisy patches.
|
|
41
45
|
masks : torch.Tensor
|
|
42
46
|
Array containing masked pixel locations.
|
|
File without changes
|