careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. See the registry's advisory for this release for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,23 +1,12 @@
|
|
|
1
|
+
"""Transforms supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from careamics.utils import BaseEnum
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class SupportedTransform(str, BaseEnum):
|
|
5
|
-
"""Transforms officially supported by CAREamics.
|
|
6
|
-
|
|
7
|
-
- Flip: from Albumentations, randomly flip the input horizontally, vertically or
|
|
8
|
-
both, parameter `p` can be used to set the probability to apply the transform.
|
|
9
|
-
- XYRandomRotate90: #TODO
|
|
10
|
-
- Normalize # TODO add details, in particular about the parameters
|
|
11
|
-
- ManipulateN2V # TODO add details, in particular about the parameters
|
|
12
|
-
- NDFlip
|
|
13
|
-
|
|
14
|
-
Note that while any Albumentations (see https://albumentations.ai/) transform can be
|
|
15
|
-
used in CAREamics, no check are implemented to verify the compatibility of any other
|
|
16
|
-
transforms than the ones officially supported.
|
|
17
|
-
"""
|
|
7
|
+
"""Transforms officially supported by CAREamics."""
|
|
18
8
|
|
|
19
|
-
|
|
9
|
+
XY_FLIP = "XYFlip"
|
|
20
10
|
XY_RANDOM_ROTATE90 = "XYRandomRotate90"
|
|
21
11
|
NORMALIZE = "Normalize"
|
|
22
12
|
N2V_MANIPULATE = "N2VManipulate"
|
|
23
|
-
# CUSTOM = "Custom"
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"N2VManipulateModel",
|
|
5
|
-
"
|
|
5
|
+
"XYFlipModel",
|
|
6
6
|
"NormalizeModel",
|
|
7
7
|
"XYRandomRotate90Model",
|
|
8
|
+
"XorYFlipModel",
|
|
8
9
|
]
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
from .n2v_manipulate_model import N2VManipulateModel
|
|
12
|
-
from .nd_flip_model import NDFlipModel
|
|
13
13
|
from .normalize_model import NormalizeModel
|
|
14
|
+
from .xy_flip_model import XYFlipModel
|
|
14
15
|
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Pydantic model for the XYFlip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlipModel(TransformModel):
|
|
11
|
+
"""
|
|
12
|
+
Pydantic model used to represent XYFlip transformation.
|
|
13
|
+
|
|
14
|
+
Attributes
|
|
15
|
+
----------
|
|
16
|
+
name : Literal["XYFlip"]
|
|
17
|
+
Name of the transformation.
|
|
18
|
+
p : float
|
|
19
|
+
Probability of applying the transform, by default 0.5.
|
|
20
|
+
seed : Optional[int]
|
|
21
|
+
Seed for the random number generator, by default None.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
model_config = ConfigDict(
|
|
25
|
+
validate_assignment=True,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
name: Literal["XYFlip"] = "XYFlip"
|
|
29
|
+
flip_x: bool = Field(
|
|
30
|
+
True,
|
|
31
|
+
description="Whether to flip along the X axis.",
|
|
32
|
+
)
|
|
33
|
+
flip_y: bool = Field(
|
|
34
|
+
True,
|
|
35
|
+
description="Whether to flip along the Y axis.",
|
|
36
|
+
)
|
|
37
|
+
p: float = Field(
|
|
38
|
+
0.5,
|
|
39
|
+
description="Probability of applying the transform.",
|
|
40
|
+
ge=0,
|
|
41
|
+
le=1,
|
|
42
|
+
)
|
|
43
|
+
seed: Optional[int] = None
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
3
4
|
|
|
4
5
|
from pydantic import ConfigDict, Field
|
|
5
6
|
|
|
@@ -8,16 +9,16 @@ from .transform_model import TransformModel
|
|
|
8
9
|
|
|
9
10
|
class XYRandomRotate90Model(TransformModel):
|
|
10
11
|
"""
|
|
11
|
-
Pydantic model used to represent
|
|
12
|
+
Pydantic model used to represent the XY random 90 degree rotation transformation.
|
|
12
13
|
|
|
13
14
|
Attributes
|
|
14
15
|
----------
|
|
15
16
|
name : Literal["XYRandomRotate90"]
|
|
16
17
|
Name of the transformation.
|
|
17
18
|
p : float
|
|
18
|
-
Probability of applying the
|
|
19
|
-
|
|
20
|
-
|
|
19
|
+
Probability of applying the transform, by default 0.5.
|
|
20
|
+
seed : Optional[int]
|
|
21
|
+
Seed for the random number generator, by default None.
|
|
21
22
|
"""
|
|
22
23
|
|
|
23
24
|
model_config = ConfigDict(
|
|
@@ -25,5 +26,10 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
25
26
|
)
|
|
26
27
|
|
|
27
28
|
name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
|
|
28
|
-
p: float = Field(
|
|
29
|
-
|
|
29
|
+
p: float = Field(
|
|
30
|
+
0.5,
|
|
31
|
+
description="Probability of applying the transform.",
|
|
32
|
+
ge=0,
|
|
33
|
+
le=1,
|
|
34
|
+
)
|
|
35
|
+
seed: Optional[int] = None
|
careamics/conftest.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
See https://sybil.readthedocs.io/en/latest/use.html#pytest
|
|
4
4
|
"""
|
|
5
|
+
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
|
|
7
8
|
import pytest
|
|
@@ -13,6 +14,18 @@ from sybil.parsers.doctest import DocTestParser
|
|
|
13
14
|
|
|
14
15
|
@pytest.fixture(scope="module")
|
|
15
16
|
def my_path(tmpdir_factory: TempPathFactory) -> Path:
|
|
17
|
+
"""Fixture used in doctest to create a temporary directory.
|
|
18
|
+
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
tmpdir_factory : TempPathFactory
|
|
22
|
+
Temporary path factory from pytest.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
Path
|
|
27
|
+
Temporary directory path.
|
|
28
|
+
"""
|
|
16
29
|
return tmpdir_factory.mktemp("my_path")
|
|
17
30
|
|
|
18
31
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Dataset utilities."""
|
|
2
|
+
|
|
2
3
|
from typing import List, Tuple
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
@@ -16,12 +17,12 @@ def _get_shape_order(
|
|
|
16
17
|
|
|
17
18
|
Parameters
|
|
18
19
|
----------
|
|
19
|
-
shape_in : Tuple
|
|
20
|
+
shape_in : Tuple[int, ...]
|
|
20
21
|
Input shape.
|
|
21
|
-
ref_axes : str
|
|
22
|
-
Reference axes.
|
|
23
22
|
axes_in : str
|
|
24
23
|
Input axes.
|
|
24
|
+
ref_axes : str
|
|
25
|
+
Reference axes.
|
|
25
26
|
|
|
26
27
|
Returns
|
|
27
28
|
-------
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""File utilities."""
|
|
2
|
+
|
|
1
3
|
from fnmatch import fnmatch
|
|
2
4
|
from pathlib import Path
|
|
3
5
|
from typing import List, Union
|
|
@@ -11,8 +13,7 @@ logger = get_logger(__name__)
|
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
def get_files_size(files: List[Path]) -> float:
|
|
14
|
-
"""
|
|
15
|
-
Get files size in MB.
|
|
16
|
+
"""Get files size in MB.
|
|
16
17
|
|
|
17
18
|
Parameters
|
|
18
19
|
----------
|
|
@@ -32,7 +33,7 @@ def list_files(
|
|
|
32
33
|
data_type: Union[str, SupportedData],
|
|
33
34
|
extension_filter: str = "",
|
|
34
35
|
) -> List[Path]:
|
|
35
|
-
"""
|
|
36
|
+
"""Create a recursive list of files in `data_path`.
|
|
36
37
|
|
|
37
38
|
If `data_path` is a file, its name is validated against the `data_type` using
|
|
38
39
|
`fnmatch`, and the method returns `data_path` itself.
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Funtions to read tiff images."""
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from fnmatch import fnmatch
|
|
3
5
|
from pathlib import Path
|
|
@@ -19,8 +21,10 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
|
|
|
19
21
|
----------
|
|
20
22
|
file_path : Path
|
|
21
23
|
Path to a file.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
+
*args : list
|
|
25
|
+
Additional arguments.
|
|
26
|
+
**kwargs : dict
|
|
27
|
+
Additional keyword arguments.
|
|
24
28
|
|
|
25
29
|
Returns
|
|
26
30
|
-------
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Function to read zarr images."""
|
|
2
|
+
|
|
1
3
|
from typing import Union
|
|
2
4
|
|
|
3
5
|
from zarr import Group, core, hierarchy, storage
|
|
@@ -6,26 +8,28 @@ from zarr import Group, core, hierarchy, storage
|
|
|
6
8
|
def read_zarr(
|
|
7
9
|
zarr_source: Group, axes: str
|
|
8
10
|
) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
|
|
9
|
-
"""
|
|
11
|
+
"""Read a file and returns a pointer.
|
|
10
12
|
|
|
11
13
|
Parameters
|
|
12
14
|
----------
|
|
13
|
-
|
|
14
|
-
|
|
15
|
+
zarr_source : Group
|
|
16
|
+
Zarr storage.
|
|
17
|
+
axes : str
|
|
18
|
+
Axes of the data.
|
|
15
19
|
|
|
16
20
|
Returns
|
|
17
21
|
-------
|
|
18
22
|
np.ndarray
|
|
19
|
-
Pointer to zarr storage
|
|
23
|
+
Pointer to zarr storage.
|
|
20
24
|
|
|
21
25
|
Raises
|
|
22
26
|
------
|
|
23
27
|
ValueError, OSError
|
|
24
|
-
if a file is not a valid tiff or damaged
|
|
28
|
+
if a file is not a valid tiff or damaged.
|
|
25
29
|
ValueError
|
|
26
|
-
if data dimensions are not 2, 3 or 4
|
|
30
|
+
if data dimensions are not 2, 3 or 4.
|
|
27
31
|
ValueError
|
|
28
|
-
if axes parameter from config is not consistent with data dimensions
|
|
32
|
+
if axes parameter from config is not consistent with data dimensions.
|
|
29
33
|
"""
|
|
30
34
|
if isinstance(zarr_source, hierarchy.Group):
|
|
31
35
|
array = zarr_source[0]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""In-memory dataset module."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
import copy
|
|
@@ -8,11 +9,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
from torch.utils.data import Dataset
|
|
10
11
|
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
11
14
|
from ..config import DataConfig, InferenceConfig
|
|
12
15
|
from ..config.tile_information import TileInformation
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
13
17
|
from ..utils.logging import get_logger
|
|
14
18
|
from .dataset_utils import read_tiff, reshape_array
|
|
15
|
-
from .patching.patch_transform import get_patch_transform
|
|
16
19
|
from .patching.patching import (
|
|
17
20
|
prepare_patches_supervised,
|
|
18
21
|
prepare_patches_supervised_array,
|
|
@@ -25,24 +28,49 @@ logger = get_logger(__name__)
|
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
class InMemoryDataset(Dataset):
|
|
28
|
-
"""Dataset storing data in memory and allowing generating patches from it.
|
|
31
|
+
"""Dataset storing data in memory and allowing generating patches from it.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
data_config : DataConfig
|
|
36
|
+
Data configuration.
|
|
37
|
+
inputs : Union[np.ndarray, List[Path]]
|
|
38
|
+
Input data.
|
|
39
|
+
input_target : Optional[Union[np.ndarray, List[Path]]], optional
|
|
40
|
+
Target data, by default None.
|
|
41
|
+
read_source_func : Callable, optional
|
|
42
|
+
Read source function for custom types, by default read_tiff.
|
|
43
|
+
**kwargs : Any
|
|
44
|
+
Additional keyword arguments, unused.
|
|
45
|
+
"""
|
|
29
46
|
|
|
30
47
|
def __init__(
|
|
31
48
|
self,
|
|
32
49
|
data_config: DataConfig,
|
|
33
50
|
inputs: Union[np.ndarray, List[Path]],
|
|
34
|
-
|
|
51
|
+
input_target: Optional[Union[np.ndarray, List[Path]]] = None,
|
|
35
52
|
read_source_func: Callable = read_tiff,
|
|
36
53
|
**kwargs: Any,
|
|
37
54
|
) -> None:
|
|
38
55
|
"""
|
|
39
56
|
Constructor.
|
|
40
57
|
|
|
41
|
-
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
data_config : DataConfig
|
|
61
|
+
Data configuration.
|
|
62
|
+
inputs : Union[np.ndarray, List[Path]]
|
|
63
|
+
Input data.
|
|
64
|
+
input_target : Optional[Union[np.ndarray, List[Path]]], optional
|
|
65
|
+
Target data, by default None.
|
|
66
|
+
read_source_func : Callable, optional
|
|
67
|
+
Read source function for custom types, by default read_tiff.
|
|
68
|
+
**kwargs : Any
|
|
69
|
+
Additional keyword arguments, unused.
|
|
42
70
|
"""
|
|
43
71
|
self.data_config = data_config
|
|
44
72
|
self.inputs = inputs
|
|
45
|
-
self.
|
|
73
|
+
self.input_targets = input_target
|
|
46
74
|
self.axes = self.data_config.axes
|
|
47
75
|
self.patch_size = self.data_config.patch_size
|
|
48
76
|
|
|
@@ -50,28 +78,25 @@ class InMemoryDataset(Dataset):
|
|
|
50
78
|
self.read_source_func = read_source_func
|
|
51
79
|
|
|
52
80
|
# Generate patches
|
|
53
|
-
supervised = self.
|
|
54
|
-
|
|
81
|
+
supervised = self.input_targets is not None
|
|
82
|
+
patch_data = self._prepare_patches(supervised)
|
|
55
83
|
|
|
56
84
|
# Add results to members
|
|
57
|
-
self.
|
|
85
|
+
self.patches, self.patch_targets, computed_mean, computed_std = patch_data
|
|
58
86
|
|
|
59
87
|
if not self.data_config.mean or not self.data_config.std:
|
|
60
88
|
self.mean, self.std = computed_mean, computed_std
|
|
61
89
|
logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
|
|
62
90
|
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
# the object is mutable and should then be recorded in the CAREamist obj
|
|
67
|
-
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
91
|
+
# update mean and std in configuration
|
|
92
|
+
# the object is mutable and should then be recorded in the CAREamist obj
|
|
93
|
+
self.data_config.set_mean_and_std(self.mean, self.std)
|
|
68
94
|
else:
|
|
69
95
|
self.mean, self.std = self.data_config.mean, self.data_config.std
|
|
70
96
|
|
|
71
97
|
# get transforms
|
|
72
|
-
self.patch_transform =
|
|
73
|
-
|
|
74
|
-
with_target=self.data_target is not None,
|
|
98
|
+
self.patch_transform = Compose(
|
|
99
|
+
transform_list=self.data_config.transforms,
|
|
75
100
|
)
|
|
76
101
|
|
|
77
102
|
def _prepare_patches(
|
|
@@ -92,18 +117,18 @@ class InMemoryDataset(Dataset):
|
|
|
92
117
|
"""
|
|
93
118
|
if supervised:
|
|
94
119
|
if isinstance(self.inputs, np.ndarray) and isinstance(
|
|
95
|
-
self.
|
|
120
|
+
self.input_targets, np.ndarray
|
|
96
121
|
):
|
|
97
122
|
return prepare_patches_supervised_array(
|
|
98
123
|
self.inputs,
|
|
99
124
|
self.axes,
|
|
100
|
-
self.
|
|
125
|
+
self.input_targets,
|
|
101
126
|
self.patch_size,
|
|
102
127
|
)
|
|
103
|
-
elif isinstance(self.inputs, list) and isinstance(self.
|
|
128
|
+
elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
|
|
104
129
|
return prepare_patches_supervised(
|
|
105
130
|
self.inputs,
|
|
106
|
-
self.
|
|
131
|
+
self.input_targets,
|
|
107
132
|
self.axes,
|
|
108
133
|
self.patch_size,
|
|
109
134
|
self.read_source_func,
|
|
@@ -112,7 +137,7 @@ class InMemoryDataset(Dataset):
|
|
|
112
137
|
raise ValueError(
|
|
113
138
|
f"Data and target must be of the same type, either both numpy "
|
|
114
139
|
f"arrays or both lists of paths, got {type(self.inputs)} (data) "
|
|
115
|
-
f"and {type(self.
|
|
140
|
+
f"and {type(self.input_targets)} (target)."
|
|
116
141
|
)
|
|
117
142
|
else:
|
|
118
143
|
if isinstance(self.inputs, np.ndarray):
|
|
@@ -138,9 +163,9 @@ class InMemoryDataset(Dataset):
|
|
|
138
163
|
int
|
|
139
164
|
Length of the dataset.
|
|
140
165
|
"""
|
|
141
|
-
return len(self.
|
|
166
|
+
return len(self.patches)
|
|
142
167
|
|
|
143
|
-
def __getitem__(self, index: int) -> Tuple[np.ndarray]:
|
|
168
|
+
def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
|
|
144
169
|
"""
|
|
145
170
|
Return the patch corresponding to the provided index.
|
|
146
171
|
|
|
@@ -159,40 +184,17 @@ class InMemoryDataset(Dataset):
|
|
|
159
184
|
ValueError
|
|
160
185
|
If dataset mean and std are not set.
|
|
161
186
|
"""
|
|
162
|
-
patch = self.
|
|
187
|
+
patch = self.patches[index]
|
|
163
188
|
|
|
164
189
|
# if there is a target
|
|
165
|
-
if self.
|
|
190
|
+
if self.patch_targets is not None:
|
|
166
191
|
# get target
|
|
167
|
-
target = self.
|
|
168
|
-
|
|
169
|
-
# Albumentations requires Channel last
|
|
170
|
-
c_patch = np.moveaxis(patch, 0, -1)
|
|
171
|
-
c_target = np.moveaxis(target, 0, -1)
|
|
192
|
+
target = self.patch_targets[index]
|
|
172
193
|
|
|
173
|
-
|
|
174
|
-
transformed = self.patch_transform(image=c_patch, target=c_target)
|
|
175
|
-
|
|
176
|
-
# move axes back
|
|
177
|
-
patch = np.moveaxis(transformed["image"], -1, 0)
|
|
178
|
-
target = np.moveaxis(transformed["target"], -1, 0)
|
|
179
|
-
|
|
180
|
-
return patch, target
|
|
194
|
+
return self.patch_transform(patch=patch, target=target)
|
|
181
195
|
|
|
182
196
|
elif self.data_config.has_n2v_manipulate():
|
|
183
|
-
|
|
184
|
-
patch = np.moveaxis(patch, 0, -1)
|
|
185
|
-
|
|
186
|
-
# Apply transforms
|
|
187
|
-
transformed_patch = self.patch_transform(image=patch)["image"]
|
|
188
|
-
manip_patch, patch, mask = transformed_patch
|
|
189
|
-
|
|
190
|
-
# move C axes back
|
|
191
|
-
manip_patch = np.moveaxis(manip_patch, -1, 0)
|
|
192
|
-
patch = np.moveaxis(patch, -1, 0)
|
|
193
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
194
|
-
|
|
195
|
-
return (manip_patch, patch, mask)
|
|
197
|
+
return self.patch_transform(patch=patch)
|
|
196
198
|
else:
|
|
197
199
|
raise ValueError(
|
|
198
200
|
"Something went wrong! No target provided (not supervised training) "
|
|
@@ -247,25 +249,25 @@ class InMemoryDataset(Dataset):
|
|
|
247
249
|
indices = np.random.choice(total_patches, n_patches, replace=False)
|
|
248
250
|
|
|
249
251
|
# extract patches
|
|
250
|
-
val_patches = self.
|
|
252
|
+
val_patches = self.patches[indices]
|
|
251
253
|
|
|
252
254
|
# remove patches from self.patch
|
|
253
|
-
self.
|
|
255
|
+
self.patches = np.delete(self.patches, indices, axis=0)
|
|
254
256
|
|
|
255
257
|
# same for targets
|
|
256
|
-
if self.
|
|
257
|
-
val_targets = self.
|
|
258
|
-
self.
|
|
258
|
+
if self.patch_targets is not None:
|
|
259
|
+
val_targets = self.patch_targets[indices]
|
|
260
|
+
self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
|
|
259
261
|
|
|
260
262
|
# clone the dataset
|
|
261
263
|
dataset = copy.deepcopy(self)
|
|
262
264
|
|
|
263
265
|
# reassign patches
|
|
264
|
-
dataset.
|
|
266
|
+
dataset.patches = val_patches
|
|
265
267
|
|
|
266
268
|
# reassign targets
|
|
267
|
-
if self.
|
|
268
|
-
dataset.
|
|
269
|
+
if self.patch_targets is not None:
|
|
270
|
+
dataset.patch_targets = val_targets
|
|
269
271
|
|
|
270
272
|
return dataset
|
|
271
273
|
|
|
@@ -274,7 +276,16 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
274
276
|
"""
|
|
275
277
|
Dataset storing data in memory and allowing generating patches from it.
|
|
276
278
|
|
|
277
|
-
|
|
279
|
+
Parameters
|
|
280
|
+
----------
|
|
281
|
+
prediction_config : InferenceConfig
|
|
282
|
+
Prediction configuration.
|
|
283
|
+
inputs : np.ndarray
|
|
284
|
+
Input data.
|
|
285
|
+
data_target : Optional[np.ndarray], optional
|
|
286
|
+
Target data, by default None.
|
|
287
|
+
read_source_func : Optional[Callable], optional
|
|
288
|
+
Read source function for custom types, by default read_tiff.
|
|
278
289
|
"""
|
|
279
290
|
|
|
280
291
|
def __init__(
|
|
@@ -288,10 +299,14 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
288
299
|
|
|
289
300
|
Parameters
|
|
290
301
|
----------
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
302
|
+
prediction_config : InferenceConfig
|
|
303
|
+
Prediction configuration.
|
|
304
|
+
inputs : np.ndarray
|
|
305
|
+
Input data.
|
|
306
|
+
data_target : Optional[np.ndarray], optional
|
|
307
|
+
Target data, by default None.
|
|
308
|
+
read_source_func : Optional[Callable], optional
|
|
309
|
+
Read source function for custom types, by default read_tiff.
|
|
295
310
|
|
|
296
311
|
Raises
|
|
297
312
|
------
|
|
@@ -318,9 +333,8 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
318
333
|
self.mean, self.std = self.pred_config.mean, self.pred_config.std
|
|
319
334
|
|
|
320
335
|
# get transforms
|
|
321
|
-
self.patch_transform =
|
|
322
|
-
|
|
323
|
-
with_target=self.data_target is not None,
|
|
336
|
+
self.patch_transform = Compose(
|
|
337
|
+
transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
|
|
324
338
|
)
|
|
325
339
|
|
|
326
340
|
def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
|
|
@@ -335,7 +349,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
335
349
|
# reshape array
|
|
336
350
|
reshaped_sample = reshape_array(self.input_array, self.axes)
|
|
337
351
|
|
|
338
|
-
if self.tiling:
|
|
352
|
+
if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
|
|
339
353
|
# generate patches, which returns a generator
|
|
340
354
|
patch_generator = extract_tiles(
|
|
341
355
|
arr=reshaped_sample,
|
|
@@ -379,13 +393,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
379
393
|
"""
|
|
380
394
|
tile_array, tile_info = self.data[index]
|
|
381
395
|
|
|
382
|
-
# Albumentations requires channel last, use the XArrayTile array
|
|
383
|
-
patch = np.moveaxis(tile_array, 0, -1)
|
|
384
|
-
|
|
385
396
|
# Apply transforms
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
# move C axes back
|
|
389
|
-
transformed_patch = np.moveaxis(transformed_patch, -1, 0)
|
|
397
|
+
transformed_tile, _ = self.patch_transform(patch=tile_array)
|
|
390
398
|
|
|
391
|
-
return
|
|
399
|
+
return transformed_tile, tile_info
|