careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,23 +1,11 @@
|
|
|
1
|
+
"""Transforms supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from careamics.utils import BaseEnum
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class SupportedTransform(str, BaseEnum):
|
|
5
|
-
"""Transforms officially supported by CAREamics.
|
|
6
|
-
|
|
7
|
-
- Flip: from Albumentations, randomly flip the input horizontally, vertically or
|
|
8
|
-
both, parameter `p` can be used to set the probability to apply the transform.
|
|
9
|
-
- XYRandomRotate90: #TODO
|
|
10
|
-
- Normalize # TODO add details, in particular about the parameters
|
|
11
|
-
- ManipulateN2V # TODO add details, in particular about the parameters
|
|
12
|
-
- NDFlip
|
|
13
|
-
|
|
14
|
-
Note that while any Albumentations (see https://albumentations.ai/) transform can be
|
|
15
|
-
used in CAREamics, no check are implemented to verify the compatibility of any other
|
|
16
|
-
transforms than the ones officially supported.
|
|
17
|
-
"""
|
|
7
|
+
"""Transforms officially supported by CAREamics."""
|
|
18
8
|
|
|
19
|
-
|
|
9
|
+
XY_FLIP = "XYFlip"
|
|
20
10
|
XY_RANDOM_ROTATE90 = "XYRandomRotate90"
|
|
21
|
-
NORMALIZE = "Normalize"
|
|
22
11
|
N2V_MANIPULATE = "N2VManipulate"
|
|
23
|
-
# CUSTOM = "Custom"
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
"""Pydantic model representing the metadata of a prediction tile."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from pydantic import BaseModel, ConfigDict,
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, field_validator
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class TileInformation(BaseModel):
|
|
@@ -11,30 +11,33 @@ class TileInformation(BaseModel):
|
|
|
11
11
|
|
|
12
12
|
This model is used to represent the information required to stitch back a tile into
|
|
13
13
|
a larger image. It is used throughout the prediction pipeline of CAREamics.
|
|
14
|
+
|
|
15
|
+
Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
|
|
16
|
+
contain singleton dimensions.
|
|
14
17
|
"""
|
|
15
18
|
|
|
16
19
|
model_config = ConfigDict(validate_default=True)
|
|
17
20
|
|
|
18
|
-
array_shape:
|
|
19
|
-
tiled: bool = False
|
|
21
|
+
array_shape: tuple[int, ...]
|
|
20
22
|
last_tile: bool = False
|
|
21
|
-
overlap_crop_coords:
|
|
22
|
-
stitch_coords:
|
|
23
|
+
overlap_crop_coords: tuple[tuple[int, ...], ...]
|
|
24
|
+
stitch_coords: tuple[tuple[int, ...], ...]
|
|
25
|
+
sample_id: int
|
|
23
26
|
|
|
24
27
|
@field_validator("array_shape")
|
|
25
28
|
@classmethod
|
|
26
|
-
def no_singleton_dimensions(cls, v:
|
|
29
|
+
def no_singleton_dimensions(cls, v: tuple[int, ...]):
|
|
27
30
|
"""
|
|
28
31
|
Check that the array shape does not have any singleton dimensions.
|
|
29
32
|
|
|
30
33
|
Parameters
|
|
31
34
|
----------
|
|
32
|
-
v :
|
|
35
|
+
v : tuple of int
|
|
33
36
|
Array shape to check.
|
|
34
37
|
|
|
35
38
|
Returns
|
|
36
39
|
-------
|
|
37
|
-
|
|
40
|
+
tuple of int
|
|
38
41
|
The array shape if it does not contain singleton dimensions.
|
|
39
42
|
|
|
40
43
|
Raises
|
|
@@ -46,59 +49,26 @@ class TileInformation(BaseModel):
|
|
|
46
49
|
raise ValueError("Array shape must not contain singleton dimensions.")
|
|
47
50
|
return v
|
|
48
51
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def only_if_tiled(cls, v: bool, values: ValidationInfo):
|
|
52
|
-
"""
|
|
53
|
-
Check that the last tile flag is only set if tiling is enabled.
|
|
52
|
+
def __eq__(self, other_tile: object):
|
|
53
|
+
"""Check if two tile information objects are equal.
|
|
54
54
|
|
|
55
55
|
Parameters
|
|
56
56
|
----------
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
values : ValidationInfo
|
|
60
|
-
Validation information.
|
|
57
|
+
other_tile : object
|
|
58
|
+
Tile information object to compare with.
|
|
61
59
|
|
|
62
60
|
Returns
|
|
63
61
|
-------
|
|
64
62
|
bool
|
|
65
|
-
|
|
66
|
-
"""
|
|
67
|
-
if not values.data["tiled"]:
|
|
68
|
-
return False
|
|
69
|
-
return v
|
|
70
|
-
|
|
71
|
-
@field_validator("overlap_crop_coords", "stitch_coords")
|
|
72
|
-
@classmethod
|
|
73
|
-
def mandatory_if_tiled(
|
|
74
|
-
cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
|
|
75
|
-
) -> Optional[Tuple[int, ...]]:
|
|
76
|
-
"""
|
|
77
|
-
Check that the coordinates are not `None` if tiling is enabled.
|
|
78
|
-
|
|
79
|
-
The method also return `None` if tiling is not enabled.
|
|
80
|
-
|
|
81
|
-
Parameters
|
|
82
|
-
----------
|
|
83
|
-
v : Optional[Tuple[int, ...]]
|
|
84
|
-
Coordinates to check.
|
|
85
|
-
values : ValidationInfo
|
|
86
|
-
Validation information.
|
|
87
|
-
|
|
88
|
-
Returns
|
|
89
|
-
-------
|
|
90
|
-
Optional[Tuple[int, ...]]
|
|
91
|
-
The coordinates if tiling is enabled, otherwise `None`.
|
|
92
|
-
|
|
93
|
-
Raises
|
|
94
|
-
------
|
|
95
|
-
ValueError
|
|
96
|
-
If the coordinates are `None` and tiling is enabled.
|
|
63
|
+
Whether the two tile information objects are equal.
|
|
97
64
|
"""
|
|
98
|
-
if
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
65
|
+
if not isinstance(other_tile, TileInformation):
|
|
66
|
+
return NotImplemented
|
|
67
|
+
|
|
68
|
+
return (
|
|
69
|
+
self.array_shape == other_tile.array_shape
|
|
70
|
+
and self.last_tile == other_tile.last_tile
|
|
71
|
+
and self.overlap_crop_coords == other_tile.overlap_crop_coords
|
|
72
|
+
and self.stitch_coords == other_tile.stitch_coords
|
|
73
|
+
and self.sample_id == other_tile.sample_id
|
|
74
|
+
)
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"N2VManipulateModel",
|
|
5
|
-
"
|
|
5
|
+
"XYFlipModel",
|
|
6
6
|
"NormalizeModel",
|
|
7
7
|
"XYRandomRotate90Model",
|
|
8
|
+
"XorYFlipModel",
|
|
8
9
|
]
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
from .n2v_manipulate_model import N2VManipulateModel
|
|
12
|
-
from .nd_flip_model import NDFlipModel
|
|
13
13
|
from .normalize_model import NormalizeModel
|
|
14
|
+
from .xy_flip_model import XYFlipModel
|
|
14
15
|
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""Pydantic model for the Normalize transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal, Optional
|
|
4
4
|
|
|
5
|
-
from pydantic import ConfigDict, Field
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
6
7
|
|
|
7
8
|
from .transform_model import TransformModel
|
|
8
9
|
|
|
@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
|
|
|
28
29
|
)
|
|
29
30
|
|
|
30
31
|
name: Literal["Normalize"] = "Normalize"
|
|
31
|
-
|
|
32
|
-
|
|
32
|
+
image_means: list = Field(..., min_length=0, max_length=32)
|
|
33
|
+
image_stds: list = Field(..., min_length=0, max_length=32)
|
|
34
|
+
target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
|
|
35
|
+
target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
|
|
36
|
+
|
|
37
|
+
@model_validator(mode="after")
|
|
38
|
+
def validate_means_stds(self: Self) -> Self:
|
|
39
|
+
"""Validate that the means and stds have the same length.
|
|
40
|
+
|
|
41
|
+
Returns
|
|
42
|
+
-------
|
|
43
|
+
Self
|
|
44
|
+
The instance of the model.
|
|
45
|
+
"""
|
|
46
|
+
if len(self.image_means) != len(self.image_stds):
|
|
47
|
+
raise ValueError("The number of image means and stds must be the same.")
|
|
48
|
+
|
|
49
|
+
if (self.target_means is None) != (self.target_stds is None):
|
|
50
|
+
raise ValueError(
|
|
51
|
+
"Both target means and stds must be provided together, or bot None."
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
if self.target_means is not None and self.target_stds is not None:
|
|
55
|
+
if len(self.target_means) != len(self.target_stds):
|
|
56
|
+
raise ValueError(
|
|
57
|
+
"The number of target means and stds must be the same."
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
return self
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Pydantic model for the XYFlip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlipModel(TransformModel):
|
|
11
|
+
"""
|
|
12
|
+
Pydantic model used to represent XYFlip transformation.
|
|
13
|
+
|
|
14
|
+
Attributes
|
|
15
|
+
----------
|
|
16
|
+
name : Literal["XYFlip"]
|
|
17
|
+
Name of the transformation.
|
|
18
|
+
p : float
|
|
19
|
+
Probability of applying the transform, by default 0.5.
|
|
20
|
+
seed : Optional[int]
|
|
21
|
+
Seed for the random number generator, by default None.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
model_config = ConfigDict(
|
|
25
|
+
validate_assignment=True,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
name: Literal["XYFlip"] = "XYFlip"
|
|
29
|
+
flip_x: bool = Field(
|
|
30
|
+
True,
|
|
31
|
+
description="Whether to flip along the X axis.",
|
|
32
|
+
)
|
|
33
|
+
flip_y: bool = Field(
|
|
34
|
+
True,
|
|
35
|
+
description="Whether to flip along the Y axis.",
|
|
36
|
+
)
|
|
37
|
+
p: float = Field(
|
|
38
|
+
0.5,
|
|
39
|
+
description="Probability of applying the transform.",
|
|
40
|
+
ge=0,
|
|
41
|
+
le=1,
|
|
42
|
+
)
|
|
43
|
+
seed: Optional[int] = None
|
|
@@ -2,21 +2,23 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Literal, Optional
|
|
4
4
|
|
|
5
|
-
from pydantic import ConfigDict
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
6
|
|
|
7
7
|
from .transform_model import TransformModel
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class XYRandomRotate90Model(TransformModel):
|
|
11
11
|
"""
|
|
12
|
-
Pydantic model used to represent
|
|
12
|
+
Pydantic model used to represent the XY random 90 degree rotation transformation.
|
|
13
13
|
|
|
14
14
|
Attributes
|
|
15
15
|
----------
|
|
16
16
|
name : Literal["XYRandomRotate90"]
|
|
17
17
|
Name of the transformation.
|
|
18
|
+
p : float
|
|
19
|
+
Probability of applying the transform, by default 0.5.
|
|
18
20
|
seed : Optional[int]
|
|
19
|
-
Seed for the random number generator.
|
|
21
|
+
Seed for the random number generator, by default None.
|
|
20
22
|
"""
|
|
21
23
|
|
|
22
24
|
model_config = ConfigDict(
|
|
@@ -24,4 +26,10 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
24
26
|
)
|
|
25
27
|
|
|
26
28
|
name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
|
|
29
|
+
p: float = Field(
|
|
30
|
+
0.5,
|
|
31
|
+
description="Probability of applying the transform.",
|
|
32
|
+
ge=0,
|
|
33
|
+
le=1,
|
|
34
|
+
)
|
|
27
35
|
seed: Optional[int] = None
|
|
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
|
|
|
72
72
|
If the value is not a power of 2.
|
|
73
73
|
"""
|
|
74
74
|
if value < 8:
|
|
75
|
-
raise ValueError(f"Value must be
|
|
75
|
+
raise ValueError(f"Value must be greater than 8 (got {value}).")
|
|
76
76
|
|
|
77
77
|
if (value & (value - 1)) != 0:
|
|
78
78
|
raise ValueError(f"Value must be a power of 2 (got {value}).")
|
careamics/conftest.py
CHANGED
|
@@ -14,6 +14,18 @@ from sybil.parsers.doctest import DocTestParser
|
|
|
14
14
|
|
|
15
15
|
@pytest.fixture(scope="module")
|
|
16
16
|
def my_path(tmpdir_factory: TempPathFactory) -> Path:
|
|
17
|
+
"""Fixture used in doctest to create a temporary directory.
|
|
18
|
+
|
|
19
|
+
Parameters
|
|
20
|
+
----------
|
|
21
|
+
tmpdir_factory : TempPathFactory
|
|
22
|
+
Temporary path factory from pytest.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
Path
|
|
27
|
+
Temporary directory path.
|
|
28
|
+
"""
|
|
17
29
|
return tmpdir_factory.mktemp("my_path")
|
|
18
30
|
|
|
19
31
|
|
careamics/dataset/__init__.py
CHANGED
|
@@ -1,6 +1,17 @@
|
|
|
1
1
|
"""Dataset module."""
|
|
2
2
|
|
|
3
|
-
__all__ = [
|
|
3
|
+
__all__ = [
|
|
4
|
+
"InMemoryDataset",
|
|
5
|
+
"InMemoryPredDataset",
|
|
6
|
+
"InMemoryTiledPredDataset",
|
|
7
|
+
"PathIterableDataset",
|
|
8
|
+
"IterableTiledPredDataset",
|
|
9
|
+
"IterablePredDataset",
|
|
10
|
+
]
|
|
4
11
|
|
|
5
12
|
from .in_memory_dataset import InMemoryDataset
|
|
13
|
+
from .in_memory_pred_dataset import InMemoryPredDataset
|
|
14
|
+
from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
|
|
6
15
|
from .iterable_dataset import PathIterableDataset
|
|
16
|
+
from .iterable_pred_dataset import IterablePredDataset
|
|
17
|
+
from .iterable_tiled_pred_dataset import IterableTiledPredDataset
|
|
@@ -2,17 +2,24 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"reshape_array",
|
|
5
|
+
"compute_normalization_stats",
|
|
5
6
|
"get_files_size",
|
|
6
7
|
"list_files",
|
|
7
8
|
"validate_source_target_files",
|
|
8
9
|
"read_tiff",
|
|
9
10
|
"get_read_func",
|
|
10
11
|
"read_zarr",
|
|
12
|
+
"iterate_over_files",
|
|
13
|
+
"WelfordStatistics",
|
|
11
14
|
]
|
|
12
15
|
|
|
13
16
|
|
|
14
|
-
from .dataset_utils import
|
|
17
|
+
from .dataset_utils import (
|
|
18
|
+
reshape_array,
|
|
19
|
+
)
|
|
15
20
|
from .file_utils import get_files_size, list_files, validate_source_target_files
|
|
21
|
+
from .iterate_over_files import iterate_over_files
|
|
16
22
|
from .read_tiff import read_tiff
|
|
17
23
|
from .read_utils import get_read_func
|
|
18
24
|
from .read_zarr import read_zarr
|
|
25
|
+
from .running_stats import WelfordStatistics, compute_normalization_stats
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Dataset utilities."""
|
|
2
2
|
|
|
3
3
|
from typing import List, Tuple
|
|
4
4
|
|
|
@@ -17,12 +17,12 @@ def _get_shape_order(
|
|
|
17
17
|
|
|
18
18
|
Parameters
|
|
19
19
|
----------
|
|
20
|
-
shape_in : Tuple
|
|
20
|
+
shape_in : Tuple[int, ...]
|
|
21
21
|
Input shape.
|
|
22
|
-
ref_axes : str
|
|
23
|
-
Reference axes.
|
|
24
22
|
axes_in : str
|
|
25
23
|
Input axes.
|
|
24
|
+
ref_axes : str
|
|
25
|
+
Reference axes.
|
|
26
26
|
|
|
27
27
|
Returns
|
|
28
28
|
-------
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""File utilities."""
|
|
2
|
+
|
|
1
3
|
from fnmatch import fnmatch
|
|
2
4
|
from pathlib import Path
|
|
3
5
|
from typing import List, Union
|
|
@@ -11,8 +13,7 @@ logger = get_logger(__name__)
|
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
def get_files_size(files: List[Path]) -> float:
|
|
14
|
-
"""
|
|
15
|
-
Get files size in MB.
|
|
16
|
+
"""Get files size in MB.
|
|
16
17
|
|
|
17
18
|
Parameters
|
|
18
19
|
----------
|
|
@@ -32,7 +33,7 @@ def list_files(
|
|
|
32
33
|
data_type: Union[str, SupportedData],
|
|
33
34
|
extension_filter: str = "",
|
|
34
35
|
) -> List[Path]:
|
|
35
|
-
"""
|
|
36
|
+
"""List recursively files in `data_path` and return a sorted list.
|
|
36
37
|
|
|
37
38
|
If `data_path` is a file, its name is validated against the `data_type` using
|
|
38
39
|
`fnmatch`, and the method returns `data_path` itself.
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Function to iterate over files."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Callable, Generator, Optional, Union
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import get_worker_info
|
|
10
|
+
|
|
11
|
+
from careamics.config import DataConfig, InferenceConfig
|
|
12
|
+
from careamics.utils.logging import get_logger
|
|
13
|
+
|
|
14
|
+
from .dataset_utils import reshape_array
|
|
15
|
+
from .read_tiff import read_tiff
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def iterate_over_files(
|
|
21
|
+
data_config: Union[DataConfig, InferenceConfig],
|
|
22
|
+
data_files: list[Path],
|
|
23
|
+
target_files: Optional[list[Path]] = None,
|
|
24
|
+
read_source_func: Callable = read_tiff,
|
|
25
|
+
) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
|
|
26
|
+
"""Iterate over data source and yield whole reshaped images.
|
|
27
|
+
|
|
28
|
+
Parameters
|
|
29
|
+
----------
|
|
30
|
+
data_config : CAREamics DataConfig or InferenceConfig
|
|
31
|
+
Configuration.
|
|
32
|
+
data_files : list of pathlib.Path
|
|
33
|
+
List of data files.
|
|
34
|
+
target_files : list of pathlib.Path, optional
|
|
35
|
+
List of target files, by default None.
|
|
36
|
+
read_source_func : Callable, optional
|
|
37
|
+
Function to read the source, by default read_tiff.
|
|
38
|
+
|
|
39
|
+
Yields
|
|
40
|
+
------
|
|
41
|
+
NDArray
|
|
42
|
+
Image.
|
|
43
|
+
"""
|
|
44
|
+
# When num_workers > 0, each worker process will have a different copy of the
|
|
45
|
+
# dataset object
|
|
46
|
+
# Configuring each copy independently to avoid having duplicate data returned
|
|
47
|
+
# from the workers
|
|
48
|
+
worker_info = get_worker_info()
|
|
49
|
+
worker_id = worker_info.id if worker_info is not None else 0
|
|
50
|
+
num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
51
|
+
|
|
52
|
+
# iterate over the files
|
|
53
|
+
for i, filename in enumerate(data_files):
|
|
54
|
+
# retrieve file corresponding to the worker id
|
|
55
|
+
if i % num_workers == worker_id:
|
|
56
|
+
try:
|
|
57
|
+
# read data
|
|
58
|
+
sample = read_source_func(filename, data_config.axes)
|
|
59
|
+
|
|
60
|
+
# reshape array
|
|
61
|
+
reshaped_sample = reshape_array(sample, data_config.axes)
|
|
62
|
+
|
|
63
|
+
# read target, if available
|
|
64
|
+
if target_files is not None:
|
|
65
|
+
if filename.name != target_files[i].name:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"File {filename} does not match target file "
|
|
68
|
+
f"{target_files[i]}. Have you passed sorted "
|
|
69
|
+
f"arrays?"
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# read target
|
|
73
|
+
target = read_source_func(target_files[i], data_config.axes)
|
|
74
|
+
|
|
75
|
+
# reshape target
|
|
76
|
+
reshaped_target = reshape_array(target, data_config.axes)
|
|
77
|
+
|
|
78
|
+
yield reshaped_sample, reshaped_target
|
|
79
|
+
else:
|
|
80
|
+
yield reshaped_sample, None
|
|
81
|
+
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logger.error(f"Error reading file {filename}: {e}")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Funtions to read tiff images."""
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from fnmatch import fnmatch
|
|
3
5
|
from pathlib import Path
|
|
@@ -19,8 +21,10 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
|
|
|
19
21
|
----------
|
|
20
22
|
file_path : Path
|
|
21
23
|
Path to a file.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
+
*args : list
|
|
25
|
+
Additional arguments.
|
|
26
|
+
**kwargs : dict
|
|
27
|
+
Additional keyword arguments.
|
|
24
28
|
|
|
25
29
|
Returns
|
|
26
30
|
-------
|
|
@@ -49,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
|
|
|
49
53
|
else:
|
|
50
54
|
raise ValueError(f"File {file_path} is not a valid tiff.")
|
|
51
55
|
|
|
52
|
-
# check dimensions
|
|
53
|
-
# TODO or should this really be done here? probably in the LightningDataModule
|
|
54
|
-
# TODO this should also be centralized somewhere else (validate_dimensions)
|
|
55
|
-
if len(array.shape) < 2 or len(array.shape) > 6:
|
|
56
|
-
raise ValueError(
|
|
57
|
-
f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
|
|
58
|
-
f"file {file_path})."
|
|
59
|
-
)
|
|
60
|
-
|
|
61
56
|
return array
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Function to read zarr images."""
|
|
2
|
+
|
|
1
3
|
from typing import Union
|
|
2
4
|
|
|
3
5
|
from zarr import Group, core, hierarchy, storage
|
|
@@ -6,26 +8,28 @@ from zarr import Group, core, hierarchy, storage
|
|
|
6
8
|
def read_zarr(
|
|
7
9
|
zarr_source: Group, axes: str
|
|
8
10
|
) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
|
|
9
|
-
"""
|
|
11
|
+
"""Read a file and returns a pointer.
|
|
10
12
|
|
|
11
13
|
Parameters
|
|
12
14
|
----------
|
|
13
|
-
|
|
14
|
-
|
|
15
|
+
zarr_source : Group
|
|
16
|
+
Zarr storage.
|
|
17
|
+
axes : str
|
|
18
|
+
Axes of the data.
|
|
15
19
|
|
|
16
20
|
Returns
|
|
17
21
|
-------
|
|
18
22
|
np.ndarray
|
|
19
|
-
Pointer to zarr storage
|
|
23
|
+
Pointer to zarr storage.
|
|
20
24
|
|
|
21
25
|
Raises
|
|
22
26
|
------
|
|
23
27
|
ValueError, OSError
|
|
24
|
-
if a file is not a valid tiff or damaged
|
|
28
|
+
if a file is not a valid tiff or damaged.
|
|
25
29
|
ValueError
|
|
26
|
-
if data dimensions are not 2, 3 or 4
|
|
30
|
+
if data dimensions are not 2, 3 or 4.
|
|
27
31
|
ValueError
|
|
28
|
-
if axes parameter from config is not consistent with data dimensions
|
|
32
|
+
if axes parameter from config is not consistent with data dimensions.
|
|
29
33
|
"""
|
|
30
34
|
if isinstance(zarr_source, hierarchy.Group):
|
|
31
35
|
array = zarr_source[0]
|