careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Tiled patching utilities."""
|
|
2
|
+
|
|
1
3
|
import itertools
|
|
2
4
|
from typing import Generator, List, Tuple, Union
|
|
3
5
|
|
|
@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
|
|
|
8
10
|
|
|
9
11
|
def _compute_crop_and_stitch_coords_1d(
|
|
10
12
|
axis_size: int, tile_size: int, overlap: int
|
|
11
|
-
) -> Tuple[List[Tuple[int,
|
|
13
|
+
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
|
|
12
14
|
"""
|
|
13
15
|
Compute the coordinates of each tile along an axis, given the overlap.
|
|
14
16
|
|
|
@@ -43,9 +45,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
43
45
|
stitch_coords.append(
|
|
44
46
|
(
|
|
45
47
|
i + overlap // 2 if i > 0 else 0,
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
(
|
|
49
|
+
i + tile_size - overlap // 2
|
|
50
|
+
if crop_coords[-1][1] < axis_size
|
|
51
|
+
else axis_size
|
|
52
|
+
),
|
|
49
53
|
)
|
|
50
54
|
)
|
|
51
55
|
|
|
@@ -53,9 +57,11 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
53
57
|
overlap_crop_coords.append(
|
|
54
58
|
(
|
|
55
59
|
overlap // 2 if i > 0 else 0,
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
60
|
+
(
|
|
61
|
+
tile_size - overlap // 2
|
|
62
|
+
if crop_coords[-1][1] < axis_size
|
|
63
|
+
else tile_size
|
|
64
|
+
),
|
|
59
65
|
)
|
|
60
66
|
)
|
|
61
67
|
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Training and validation Lightning data modules."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pytorch_lightning as L
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
9
|
|
|
10
10
|
from careamics.config import DataConfig
|
|
@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
95
95
|
Batch size.
|
|
96
96
|
use_in_memory : bool
|
|
97
97
|
Whether to use in memory dataset if possible.
|
|
98
|
-
train_data : Union[Path,
|
|
98
|
+
train_data : Union[Path, np.ndarray]
|
|
99
99
|
Training data.
|
|
100
|
-
val_data : Optional[Union[Path,
|
|
100
|
+
val_data : Optional[Union[Path, np.ndarray]]
|
|
101
101
|
Validation data.
|
|
102
|
-
train_data_target : Optional[Union[Path,
|
|
102
|
+
train_data_target : Optional[Union[Path, np.ndarray]]
|
|
103
103
|
Training target data.
|
|
104
|
-
val_data_target : Optional[Union[Path,
|
|
104
|
+
val_data_target : Optional[Union[Path, np.ndarray]]
|
|
105
105
|
Validation target data.
|
|
106
106
|
val_percentage : float
|
|
107
107
|
Percentage of the training data to use for validation, if no validation data is
|
|
@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
217
217
|
)
|
|
218
218
|
|
|
219
219
|
# configuration
|
|
220
|
-
self.data_config = data_config
|
|
221
|
-
self.data_type = data_config.data_type
|
|
222
|
-
self.batch_size = data_config.batch_size
|
|
223
|
-
self.use_in_memory = use_in_memory
|
|
220
|
+
self.data_config: DataConfig = data_config
|
|
221
|
+
self.data_type: str = data_config.data_type
|
|
222
|
+
self.batch_size: int = data_config.batch_size
|
|
223
|
+
self.use_in_memory: bool = use_in_memory
|
|
224
|
+
|
|
225
|
+
# data: make data Path or np.ndarray, use type annotations for mypy
|
|
226
|
+
self.train_data: Union[Path, np.ndarray] = (
|
|
227
|
+
Path(train_data) if isinstance(train_data, str) else train_data
|
|
228
|
+
)
|
|
224
229
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
230
|
+
self.val_data: Union[Path, np.ndarray] = (
|
|
231
|
+
Path(val_data) if isinstance(val_data, str) else val_data
|
|
232
|
+
)
|
|
228
233
|
|
|
229
|
-
self.train_data_target =
|
|
230
|
-
|
|
234
|
+
self.train_data_target: Union[Path, np.ndarray] = (
|
|
235
|
+
Path(train_data_target)
|
|
236
|
+
if isinstance(train_data_target, str)
|
|
237
|
+
else train_data_target
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
self.val_data_target: Union[Path, np.ndarray] = (
|
|
241
|
+
Path(val_data_target)
|
|
242
|
+
if isinstance(val_data_target, str)
|
|
243
|
+
else val_data_target
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
# validation split
|
|
231
247
|
self.val_percentage = val_percentage
|
|
232
248
|
self.val_minimum_split = val_minimum_split
|
|
233
249
|
|
|
@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
241
257
|
elif data_config.data_type != SupportedData.ARRAY:
|
|
242
258
|
self.read_source_func = get_read_func(data_config.data_type)
|
|
243
259
|
|
|
244
|
-
self.extension_filter = extension_filter
|
|
260
|
+
self.extension_filter: str = extension_filter
|
|
245
261
|
|
|
246
262
|
# Pytorch dataloader parameters
|
|
247
|
-
self.dataloader_params = (
|
|
263
|
+
self.dataloader_params: Dict[str, Any] = (
|
|
248
264
|
data_config.dataloader_params if data_config.dataloader_params else {}
|
|
249
265
|
)
|
|
250
266
|
|
|
@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
309
325
|
"""
|
|
310
326
|
# if numpy array
|
|
311
327
|
if self.data_type == SupportedData.ARRAY:
|
|
328
|
+
# mypy checks
|
|
329
|
+
assert isinstance(self.train_data, np.ndarray)
|
|
330
|
+
if self.train_data_target is not None:
|
|
331
|
+
assert isinstance(self.train_data_target, np.ndarray)
|
|
332
|
+
|
|
312
333
|
# train dataset
|
|
313
334
|
self.train_dataset: DatasetType = InMemoryDataset(
|
|
314
335
|
data_config=self.data_config,
|
|
315
336
|
inputs=self.train_data,
|
|
316
|
-
|
|
337
|
+
input_target=self.train_data_target,
|
|
317
338
|
)
|
|
318
339
|
|
|
319
340
|
# validation dataset
|
|
320
341
|
if self.val_data is not None:
|
|
342
|
+
# mypy checks
|
|
343
|
+
assert isinstance(self.val_data, np.ndarray)
|
|
344
|
+
if self.val_data_target is not None:
|
|
345
|
+
assert isinstance(self.val_data_target, np.ndarray)
|
|
346
|
+
|
|
321
347
|
# create its own dataset
|
|
322
348
|
self.val_dataset: DatasetType = InMemoryDataset(
|
|
323
349
|
data_config=self.data_config,
|
|
324
350
|
inputs=self.val_data,
|
|
325
|
-
|
|
351
|
+
input_target=self.val_data_target,
|
|
326
352
|
)
|
|
327
353
|
else:
|
|
328
354
|
# extract validation from the training patches
|
|
@@ -341,9 +367,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
341
367
|
self.train_dataset = InMemoryDataset(
|
|
342
368
|
data_config=self.data_config,
|
|
343
369
|
inputs=self.train_files,
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
370
|
+
input_target=(
|
|
371
|
+
self.train_target_files if self.train_data_target else None
|
|
372
|
+
),
|
|
347
373
|
read_source_func=self.read_source_func,
|
|
348
374
|
)
|
|
349
375
|
|
|
@@ -352,9 +378,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
352
378
|
self.val_dataset = InMemoryDataset(
|
|
353
379
|
data_config=self.data_config,
|
|
354
380
|
inputs=self.val_files,
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
381
|
+
input_target=(
|
|
382
|
+
self.val_target_files if self.val_data_target else None
|
|
383
|
+
),
|
|
358
384
|
read_source_func=self.read_source_func,
|
|
359
385
|
)
|
|
360
386
|
else:
|
|
@@ -370,9 +396,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
370
396
|
self.train_dataset = PathIterableDataset(
|
|
371
397
|
data_config=self.data_config,
|
|
372
398
|
src_files=self.train_files,
|
|
373
|
-
target_files=
|
|
374
|
-
|
|
375
|
-
|
|
399
|
+
target_files=(
|
|
400
|
+
self.train_target_files if self.train_data_target else None
|
|
401
|
+
),
|
|
376
402
|
read_source_func=self.read_source_func,
|
|
377
403
|
)
|
|
378
404
|
|
|
@@ -382,9 +408,9 @@ class CAREamicsTrainData(L.LightningDataModule):
|
|
|
382
408
|
self.val_dataset = PathIterableDataset(
|
|
383
409
|
data_config=self.data_config,
|
|
384
410
|
src_files=self.val_files,
|
|
385
|
-
target_files=
|
|
386
|
-
|
|
387
|
-
|
|
411
|
+
target_files=(
|
|
412
|
+
self.val_target_files if self.val_data_target else None
|
|
413
|
+
),
|
|
388
414
|
read_source_func=self.read_source_func,
|
|
389
415
|
)
|
|
390
416
|
elif len(self.train_files) <= self.val_minimum_split:
|
|
@@ -452,8 +478,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
452
478
|
In particular, N2V requires a specific transformation (N2V manipulates), which is
|
|
453
479
|
not compatible with supervised training. The default transformations applied to the
|
|
454
480
|
training patches are defined in `careamics.config.data_model`. To use different
|
|
455
|
-
transformations, pass a list of transforms
|
|
456
|
-
`transforms` parameter. See examples for more details.
|
|
481
|
+
transformations, pass a list of transforms. See examples for more details.
|
|
457
482
|
|
|
458
483
|
By default, CAREamics only supports types defined in
|
|
459
484
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -488,7 +513,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
488
513
|
Batch size.
|
|
489
514
|
val_data : Optional[Union[str, Path]], optional
|
|
490
515
|
Validation data, by default None.
|
|
491
|
-
transforms :
|
|
516
|
+
transforms : List[TRANSFORMS_UNION], optional
|
|
492
517
|
List of transforms to apply to training patches. If None, default transforms
|
|
493
518
|
are applied.
|
|
494
519
|
train_target_data : Optional[Union[str, Path]], optional
|
|
@@ -584,7 +609,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
584
609
|
axes: str,
|
|
585
610
|
batch_size: int,
|
|
586
611
|
val_data: Optional[Union[str, Path]] = None,
|
|
587
|
-
transforms: Optional[
|
|
612
|
+
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
588
613
|
train_target_data: Optional[Union[str, Path]] = None,
|
|
589
614
|
val_target_data: Optional[Union[str, Path]] = None,
|
|
590
615
|
read_source_func: Optional[Callable] = None,
|
|
@@ -617,8 +642,8 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
617
642
|
In particular, N2V requires a specific transformation (N2V manipulates), which
|
|
618
643
|
is not compatible with supervised training. The default transformations applied
|
|
619
644
|
to the training patches are defined in `careamics.config.data_model`. To use
|
|
620
|
-
different transformations, pass a list of transforms
|
|
621
|
-
|
|
645
|
+
different transformations, pass a list of transforms. See examples for more
|
|
646
|
+
details.
|
|
622
647
|
|
|
623
648
|
By default, CAREamics only supports types defined in
|
|
624
649
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -655,7 +680,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
655
680
|
Batch size.
|
|
656
681
|
val_data : Optional[Union[str, Path]], optional
|
|
657
682
|
Validation data, by default None.
|
|
658
|
-
transforms : Optional[
|
|
683
|
+
transforms : Optional[List[TRANSFORMS_UNION]], optional
|
|
659
684
|
List of transforms to apply to training patches. If None, default transforms
|
|
660
685
|
are applied.
|
|
661
686
|
train_target_data : Optional[Union[str, Path]], optional
|
|
@@ -709,10 +734,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
709
734
|
self.data_config = DataConfig(**data_dict)
|
|
710
735
|
|
|
711
736
|
# N2V specific checks, N2V, structN2V, and transforms
|
|
712
|
-
if (
|
|
713
|
-
self.data_config.has_transform_list()
|
|
714
|
-
and self.data_config.has_n2v_manipulate()
|
|
715
|
-
):
|
|
737
|
+
if self.data_config.has_n2v_manipulate():
|
|
716
738
|
# there is no target, n2v2 and structN2V can be changed
|
|
717
739
|
if train_target_data is None:
|
|
718
740
|
self.data_config.set_N2V2(use_n2v2)
|
careamics/lightning_module.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""CAREamics Lightning module."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Optional, Union
|
|
2
4
|
|
|
3
5
|
import pytorch_lightning as L
|
|
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
|
|
|
24
26
|
This class encapsulates a PyTorch model along with the training, validation,
|
|
25
27
|
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
26
28
|
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
algorithm_config : Union[AlgorithmConfig, dict]
|
|
32
|
+
Algorithm configuration.
|
|
33
|
+
|
|
27
34
|
Attributes
|
|
28
35
|
----------
|
|
29
36
|
model : nn.Module
|
|
@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
39
46
|
"""
|
|
40
47
|
|
|
41
48
|
def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
|
|
42
|
-
"""
|
|
43
|
-
CAREamics Lightning module.
|
|
49
|
+
"""Lightning module for CAREamics.
|
|
44
50
|
|
|
45
51
|
This class encapsulates a PyTorch model along with the training, validation,
|
|
46
52
|
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
@@ -162,7 +168,7 @@ class CAREamicsModule(L.LightningModule):
|
|
|
162
168
|
mean=self._trainer.datamodule.predict_dataset.mean,
|
|
163
169
|
std=self._trainer.datamodule.predict_dataset.std,
|
|
164
170
|
)
|
|
165
|
-
denormalized_output = denorm(
|
|
171
|
+
denormalized_output, _ = denorm(patch=output)
|
|
166
172
|
|
|
167
173
|
if len(aux) > 0:
|
|
168
174
|
return denormalized_output, aux
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|
4
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pytorch_lightning as L
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
9
|
from torch.utils.data.dataloader import default_collate
|
|
10
10
|
|
|
@@ -39,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
|
39
39
|
|
|
40
40
|
Parameters
|
|
41
41
|
----------
|
|
42
|
-
batch :
|
|
42
|
+
batch : List[Tuple[np.ndarray, TileInformation]]
|
|
43
43
|
Batch of tiles.
|
|
44
44
|
|
|
45
45
|
Returns
|
|
@@ -257,14 +257,13 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
257
257
|
|
|
258
258
|
The default transformations applied to the images are defined in
|
|
259
259
|
`careamics.config.inference_model`. To use different transformations, pass a list
|
|
260
|
-
of transforms
|
|
260
|
+
of transforms. See examples
|
|
261
261
|
for more details.
|
|
262
262
|
|
|
263
263
|
The `mean` and `std` parameters are only used if Normalization is defined either
|
|
264
|
-
in the default transformations or in the `transforms` parameter
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
to this method.
|
|
264
|
+
in the default transformations or in the `transforms` parameter. If you pass a
|
|
265
|
+
`Normalization` transform in a list as `transforms`, then the mean and std
|
|
266
|
+
parameters will be overwritten by those passed to this method.
|
|
268
267
|
|
|
269
268
|
By default, CAREamics only supports types defined in
|
|
270
269
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
@@ -276,6 +275,12 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
276
275
|
dataloaders, except for `batch_size`, which is set by the `batch_size`
|
|
277
276
|
parameter.
|
|
278
277
|
|
|
278
|
+
Note that if you are using a UNet model and tiling, the tile size must be
|
|
279
|
+
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
280
|
+
avoids artefacts arising from the broken shift invariance induced by the
|
|
281
|
+
pooling layers of the UNet. If your image has less dimensions, as it may
|
|
282
|
+
happen in the Z dimension, consider padding your image.
|
|
283
|
+
|
|
279
284
|
Parameters
|
|
280
285
|
----------
|
|
281
286
|
pred_data : Union[str, Path, np.ndarray]
|
|
@@ -298,9 +303,6 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
298
303
|
Batch size.
|
|
299
304
|
tta_transforms : bool, optional
|
|
300
305
|
Use test time augmentation, by default True.
|
|
301
|
-
transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
|
|
302
|
-
List of transforms to apply to prediction patches. If None, default
|
|
303
|
-
transforms are applied.
|
|
304
306
|
read_source_func : Optional[Callable], optional
|
|
305
307
|
Function to read the source data, used if `data_type` is `custom`, by
|
|
306
308
|
default None.
|
|
@@ -321,7 +323,6 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
321
323
|
axes: str = "YX",
|
|
322
324
|
batch_size: int = 1,
|
|
323
325
|
tta_transforms: bool = True,
|
|
324
|
-
transforms: Optional[Union[List, Compose]] = None,
|
|
325
326
|
read_source_func: Optional[Callable] = None,
|
|
326
327
|
extension_filter: str = "",
|
|
327
328
|
dataloader_params: Optional[dict] = None,
|
|
@@ -351,9 +352,6 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
351
352
|
Batch size.
|
|
352
353
|
tta_transforms : bool, optional
|
|
353
354
|
Use test time augmentation, by default True.
|
|
354
|
-
transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
|
|
355
|
-
List of transforms to apply to prediction patches. If None, default
|
|
356
|
-
transforms are applied.
|
|
357
355
|
read_source_func : Optional[Callable], optional
|
|
358
356
|
Function to read the source data, used if `data_type` is `custom`, by
|
|
359
357
|
default None.
|
|
@@ -364,7 +362,7 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
364
362
|
"""
|
|
365
363
|
if dataloader_params is None:
|
|
366
364
|
dataloader_params = {}
|
|
367
|
-
prediction_dict = {
|
|
365
|
+
prediction_dict: Dict[str, Any] = {
|
|
368
366
|
"data_type": data_type,
|
|
369
367
|
"tile_size": tile_size,
|
|
370
368
|
"tile_overlap": tile_overlap,
|
|
@@ -373,12 +371,9 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
373
371
|
"std": std,
|
|
374
372
|
"tta": tta_transforms,
|
|
375
373
|
"batch_size": batch_size,
|
|
374
|
+
"transforms": [],
|
|
376
375
|
}
|
|
377
376
|
|
|
378
|
-
# if transforms are passed (otherwise it will use the default ones)
|
|
379
|
-
if transforms is not None:
|
|
380
|
-
prediction_dict["transforms"] = transforms
|
|
381
|
-
|
|
382
377
|
# validate configuration
|
|
383
378
|
self.prediction_config = InferenceConfig(**prediction_dict)
|
|
384
379
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Lightning prediction loop allowing tiling."""
|
|
2
|
+
|
|
1
3
|
from typing import Optional
|
|
2
4
|
|
|
3
5
|
import pytorch_lightning as L
|
|
@@ -18,14 +20,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
|
|
|
18
20
|
"""
|
|
19
21
|
|
|
20
22
|
def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
|
|
21
|
-
"""
|
|
22
|
-
Calls `on_predict_epoch_end` hook.
|
|
23
|
+
"""Call `on_predict_epoch_end` hook.
|
|
23
24
|
|
|
24
25
|
Adapted from the parent method.
|
|
25
26
|
|
|
26
27
|
Returns
|
|
27
28
|
-------
|
|
28
|
-
|
|
29
|
+
Optional[_PREDICT_OUTPUT]
|
|
30
|
+
Prediction output.
|
|
29
31
|
"""
|
|
30
32
|
trainer = self.trainer
|
|
31
33
|
call._call_callback_hooks(trainer, "on_predict_epoch_end")
|
|
@@ -45,15 +47,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
|
|
|
45
47
|
|
|
46
48
|
@_no_grad_context
|
|
47
49
|
def run(self) -> Optional[_PREDICT_OUTPUT]:
|
|
48
|
-
"""
|
|
49
|
-
Runs the prediction loop.
|
|
50
|
+
"""Run the prediction loop.
|
|
50
51
|
|
|
51
52
|
Adapted from the parent method in order to stitch the predictions.
|
|
52
53
|
|
|
53
54
|
Returns
|
|
54
55
|
-------
|
|
55
56
|
Optional[_PREDICT_OUTPUT]
|
|
56
|
-
Prediction output
|
|
57
|
+
Prediction output.
|
|
57
58
|
"""
|
|
58
59
|
self.setup_data()
|
|
59
60
|
if self.skip:
|
|
@@ -86,6 +87,7 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
|
|
|
86
87
|
|
|
87
88
|
########################################################
|
|
88
89
|
################ CAREamics specific code ###############
|
|
90
|
+
# TODO: next line is not compatible with muSplit
|
|
89
91
|
is_tiled = len(self.predictions[batch_idx]) == 2
|
|
90
92
|
if is_tiled:
|
|
91
93
|
# extract the last tile flag and the coordinates (crop and stitch)
|
careamics/losses/__init__.py
CHANGED
careamics/losses/loss_factory.py
CHANGED
|
@@ -3,6 +3,7 @@ Loss factory module.
|
|
|
3
3
|
|
|
4
4
|
This module contains a factory function for creating loss functions.
|
|
5
5
|
"""
|
|
6
|
+
|
|
6
7
|
from typing import Callable, Union
|
|
7
8
|
|
|
8
9
|
from ..config.support import SupportedLoss
|
|
@@ -16,7 +17,7 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
|
16
17
|
|
|
17
18
|
Parameters
|
|
18
19
|
----------
|
|
19
|
-
loss: SupportedLoss
|
|
20
|
+
loss : Union[SupportedLoss, str]
|
|
20
21
|
Requested loss.
|
|
21
22
|
|
|
22
23
|
Returns
|
careamics/losses/losses.py
CHANGED
|
@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
-
|
|
9
|
-
# TODO if we are only using the DiceLoss, can we just implement it?
|
|
10
|
-
# from segmentation_models_pytorch.losses import DiceLoss
|
|
11
8
|
from torch.nn import L1Loss, MSELoss
|
|
12
9
|
|
|
13
10
|
|
|
14
|
-
def mse_loss(
|
|
11
|
+
def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
15
12
|
"""
|
|
16
13
|
Mean squared error loss.
|
|
17
14
|
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
source : torch.Tensor
|
|
18
|
+
Source patches.
|
|
19
|
+
target : torch.Tensor
|
|
20
|
+
Target patches.
|
|
21
|
+
|
|
18
22
|
Returns
|
|
19
23
|
-------
|
|
20
24
|
torch.Tensor
|
|
21
25
|
Loss value.
|
|
22
26
|
"""
|
|
23
27
|
loss = MSELoss()
|
|
24
|
-
return loss(
|
|
28
|
+
return loss(source, target)
|
|
25
29
|
|
|
26
30
|
|
|
27
31
|
def n2v_loss(
|
|
@@ -34,9 +38,9 @@ def n2v_loss(
|
|
|
34
38
|
|
|
35
39
|
Parameters
|
|
36
40
|
----------
|
|
37
|
-
|
|
41
|
+
manipulated_patches : torch.Tensor
|
|
38
42
|
Patches with manipulated pixels.
|
|
39
|
-
|
|
43
|
+
original_patches : torch.Tensor
|
|
40
44
|
Noisy patches.
|
|
41
45
|
masks : torch.Tensor
|
|
42
46
|
Array containing masked pixel locations.
|
careamics/model_io/__init__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Functions used to create a README.md file for BMZ export."""
|
|
2
|
+
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
from typing import Optional
|
|
4
5
|
|
|
@@ -117,4 +118,4 @@ def readme_factory(
|
|
|
117
118
|
|
|
118
119
|
readme.write_text("".join(description))
|
|
119
120
|
|
|
120
|
-
|
|
121
|
+
return readme.absolute()
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Function to export to the BioImage Model Zoo format."""
|
|
2
|
+
|
|
2
3
|
import tempfile
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import List, Optional, Tuple, Union
|
|
@@ -103,9 +104,9 @@ def export_to_bmz(
|
|
|
103
104
|
authors : List[dict]
|
|
104
105
|
Authors of the model.
|
|
105
106
|
input_array : np.ndarray
|
|
106
|
-
Input array.
|
|
107
|
+
Input array, should not have been normalized.
|
|
107
108
|
output_array : np.ndarray
|
|
108
|
-
Output array.
|
|
109
|
+
Output array, should have been denormalized.
|
|
109
110
|
channel_names : Optional[List[str]], optional
|
|
110
111
|
Channel names, by default None.
|
|
111
112
|
data_description : Optional[str], optional
|
|
@@ -177,7 +178,7 @@ def export_to_bmz(
|
|
|
177
178
|
)
|
|
178
179
|
|
|
179
180
|
# test model description
|
|
180
|
-
summary: ValidationSummary = test_model(model_description)
|
|
181
|
+
summary: ValidationSummary = test_model(model_description, decimal=2)
|
|
181
182
|
if summary.status == "failed":
|
|
182
183
|
raise ValueError(f"Model description test failed: {summary}")
|
|
183
184
|
|