careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/config/transformations/xy_random_rotate90_model.py
ADDED
@@ -0,0 +1,35 @@
+"""Pydantic model for the XYRandomRotate90 transform."""
+
+from typing import Literal, Optional
+
+from pydantic import ConfigDict, Field
+
+from .transform_model import TransformModel
+
+
+class XYRandomRotate90Model(TransformModel):
+    """
+    Pydantic model used to represent the XY random 90 degree rotation transformation.
+
+    Attributes
+    ----------
+    name : Literal["XYRandomRotate90"]
+        Name of the transformation.
+    p : float
+        Probability of applying the transform, by default 0.5.
+    seed : Optional[int]
+        Seed for the random number generator, by default None.
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
+    p: float = Field(
+        0.5,
+        description="Probability of applying the transform.",
+        ge=0,
+        le=1,
+    )
+    seed: Optional[int] = None
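For context, the sketch below (not part of the diff) illustrates how this Pydantic model behaves: `Field(ge=0, le=1)` constrains `p` to [0, 1], and `validate_assignment=True` re-validates the field when it is reassigned. The import path is an assumption based on the package layout listed above.

    # Illustrative sketch, not part of the diff; import path is an assumption
    from careamics.config.transformations import XYRandomRotate90Model

    transform = XYRandomRotate90Model(p=0.8, seed=42)
    print(transform.name)  # "XYRandomRotate90"

    try:
        transform.p = 1.5  # re-validated on assignment and rejected: outside [0, 1]
    except ValueError as err:
        print("rejected:", err)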
careamics/config/validators/validator_utils.py
ADDED
@@ -0,0 +1,101 @@
+"""
+Validator functions.
+
+These functions are used to validate dimensions and axes of inputs.
+"""
+
+from typing import List, Optional, Tuple, Union
+
+_AXES = "STCZYX"
+
+
+def check_axes_validity(axes: str) -> None:
+    """
+    Sanity check on axes.
+
+    The constraints on the axes are the following:
+    - must be a combination of 'STCZYX'
+    - must not contain duplicates
+    - must contain at least 2 contiguous axes: X and Y
+    - must contain at most 4 axes
+    - cannot contain both S and T axes
+
+    Axes do not need to be in the order 'STCZYX', as this depends on the user data.
+
+    Parameters
+    ----------
+    axes : str
+        Axes to validate.
+    """
+    _axes = axes.upper()
+
+    # Minimum is 2 (XY) and maximum is 4 (TZYX)
+    if len(_axes) < 2 or len(_axes) > 6:
+        raise ValueError(
+            f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
+        )
+
+    if "YX" not in _axes and "XY" not in _axes:
+        raise ValueError(
+            f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
+        )
+
+    # all characters must be in REF_AXES = 'STCZYX'
+    if not all(s in _AXES for s in _axes):
+        raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")
+
+    # check for repeating characters
+    for i, s in enumerate(_axes):
+        if i != _axes.rfind(s):
+            raise ValueError(
+                f"Invalid axes {axes}. Cannot contain duplicate axes"
+                f" (got multiple {axes[i]})."
+            )
+
+
+def value_ge_than_8_power_of_2(
+    value: int,
+) -> None:
+    """
+    Validate that the value is greater or equal than 8 and a power of 2.
+
+    Parameters
+    ----------
+    value : int
+        Value to validate.
+
+    Raises
+    ------
+    ValueError
+        If the value is smaller than 8.
+    ValueError
+        If the value is not a power of 2.
+    """
+    if value < 8:
+        raise ValueError(f"Value must be greater than 8 (got {value}).")
+
+    if (value & (value - 1)) != 0:
+        raise ValueError(f"Value must be a power of 2 (got {value}).")
+
+
+def patch_size_ge_than_8_power_of_2(
+    patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]],
+) -> None:
+    """
+    Validate that each entry is greater or equal than 8 and a power of 2.
+
+    Parameters
+    ----------
+    patch_list : Optional[Union[List[int]]]
+        Patch size.
+
+    Raises
+    ------
+    ValueError
+        If the patch size if smaller than 8.
+    ValueError
+        If the patch size is not a power of 2.
+    """
+    if patch_list is not None:
+        for dim in patch_list:
+            value_ge_than_8_power_of_2(dim)
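As a quick illustration (not part of the diff), these validators are plain functions that raise `ValueError` on invalid input and return `None` otherwise, so callers invoke them for their side effect. The import path is an assumption based on the validators package listed above.

    # Illustrative sketch, not part of the diff; import path is an assumption
    from careamics.config.validators import (
        check_axes_validity,
        patch_size_ge_than_8_power_of_2,
    )

    check_axes_validity("SZYX")                 # passes: subset of STCZYX containing "YX"
    patch_size_ge_than_8_power_of_2((64, 64))   # passes: every entry >= 8 and a power of 2

    try:
        check_axes_validity("YXX")              # duplicate X axis
    except ValueError as err:
        print(err)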
careamics/conftest.py
ADDED
@@ -0,0 +1,39 @@
+"""File used to discover python modules and run doctest.
+
+See https://sybil.readthedocs.io/en/latest/use.html#pytest
+"""
+
+from pathlib import Path
+
+import pytest
+from pytest import TempPathFactory
+from sybil import Sybil
+from sybil.parsers.codeblock import PythonCodeBlockParser
+from sybil.parsers.doctest import DocTestParser
+
+
+@pytest.fixture(scope="module")
+def my_path(tmpdir_factory: TempPathFactory) -> Path:
+    """Fixture used in doctest to create a temporary directory.
+
+    Parameters
+    ----------
+    tmpdir_factory : TempPathFactory
+        Temporary path factory from pytest.
+
+    Returns
+    -------
+    Path
+        Temporary directory path.
+    """
+    return tmpdir_factory.mktemp("my_path")
+
+
+pytest_collect_file = Sybil(
+    parsers=[
+        DocTestParser(),
+        PythonCodeBlockParser(future_imports=["print_function"]),
+    ],
+    pattern="*.py",
+    fixtures=["my_path"],
+).pytest()
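For context (not part of the diff), this conftest lets Sybil collect doctests embedded in the package's own modules and exposes `my_path` as a temporary directory inside those examples. A docstring written roughly like the hypothetical sketch below would be collected and executed by pytest:

    # Hypothetical docstring, not taken from the package, shown only to illustrate collection
    def save_report(text: str) -> None:
        """Write a small report file.

        Examples
        --------
        >>> report = my_path / "report.txt"   # `my_path` is the fixture declared above
        >>> _ = report.write_text("all good")
        >>> print(report.read_text())
        all good
        """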
careamics/dataset/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""Dataset module."""
+
+__all__ = [
+    "InMemoryDataset",
+    "InMemoryPredDataset",
+    "InMemoryTiledPredDataset",
+    "PathIterableDataset",
+    "IterableTiledPredDataset",
+    "IterablePredDataset",
+]
+
+from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
+from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredDataset
careamics/dataset/dataset_utils/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""Files and arrays utils used in the datasets."""
+
+__all__ = [
+    "reshape_array",
+    "compute_normalization_stats",
+    "get_files_size",
+    "list_files",
+    "validate_source_target_files",
+    "iterate_over_files",
+    "WelfordStatistics",
+]
+
+
+from .dataset_utils import (
+    reshape_array,
+)
+from .file_utils import get_files_size, list_files, validate_source_target_files
+from .iterate_over_files import iterate_over_files
+from .running_stats import WelfordStatistics, compute_normalization_stats
careamics/dataset/dataset_utils/dataset_utils.py
ADDED
@@ -0,0 +1,101 @@
+"""Dataset utilities."""
+
+from typing import List, Tuple
+
+import numpy as np
+
+from careamics.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def _get_shape_order(
+    shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
+) -> Tuple[Tuple[int, ...], str, List[int]]:
+    """
+    Compute a new shape for the array based on the reference axes.
+
+    Parameters
+    ----------
+    shape_in : Tuple[int, ...]
+        Input shape.
+    axes_in : str
+        Input axes.
+    ref_axes : str
+        Reference axes.
+
+    Returns
+    -------
+    Tuple[Tuple[int, ...], str, List[int]]
+        New shape, new axes, indices of axes in the new axes order.
+    """
+    indices = [axes_in.find(k) for k in ref_axes]
+
+    # remove all non-existing axes (index == -1)
+    new_indices = list(filter(lambda k: k != -1, indices))
+
+    # find axes order and get new shape
+    new_axes = [axes_in[ind] for ind in new_indices]
+    new_shape = tuple([shape_in[ind] for ind in new_indices])
+
+    return new_shape, "".join(new_axes), new_indices
+
+
+def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
+    """Reshape the data to (S, C, (Z), Y, X) by moving axes.
+
+    If the data has both S and T axes, the two axes will be merged. A singleton
+    dimension is added if there are no C axis.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input array.
+    axes : str
+        Description of axes in format `STCZYX`.
+
+    Returns
+    -------
+    np.ndarray
+        Reshaped array with shape (S, C, (Z), Y, X).
+    """
+    _x = x
+    _axes = axes
+
+    # sanity checks
+    if len(_axes) != len(_x.shape):
+        raise ValueError(
+            f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes "
+            f"correct?"
+        )
+
+    # get new x shape
+    new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes)
+
+    # if S is not in the list of axes, then add a singleton S
+    if "S" not in new_axes:
+        new_axes = "S" + new_axes
+        _x = _x[np.newaxis, ...]
+        new_x_shape = (1,) + new_x_shape
+
+        # need to change the array of indices
+        indices = [0] + [1 + i for i in indices]
+
+    # reshape by moving axes
+    destination = list(range(len(indices)))
+    _x = np.moveaxis(_x, indices, destination)
+
+    # remove T if necessary
+    if "T" in new_axes:
+        new_x_shape = (-1,) + new_x_shape[2:]  # remove T and S
+        new_axes = new_axes.replace("T", "")
+
+        # reshape S and T together
+        _x = _x.reshape(new_x_shape)
+
+    # add channel
+    if "C" not in new_axes:
+        # Add channel axis after S
+        _x = np.expand_dims(_x, new_axes.index("S") + 1)
+
+    return _x
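To make the axis handling concrete, here is an illustrative sketch (not part of the diff) of `reshape_array` as described by its docstring and code: axes are reordered toward `STCZYX`, S and T are merged into a single sample dimension, and a singleton channel axis is inserted when `C` is missing. The import relies on the re-export in the `__init__.py` shown above.

    # Illustrative sketch, not part of the diff
    import numpy as np
    from careamics.dataset.dataset_utils import reshape_array

    y = reshape_array(np.zeros((32, 64)), "YX")
    print(y.shape)   # (1, 1, 32, 64): singleton S and C axes added

    z = reshape_array(np.zeros((5, 3, 64, 64)), "TSYX")
    print(z.shape)   # (15, 1, 64, 64): T and S merged into 15 samples, singleton C added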
careamics/dataset/dataset_utils/file_utils.py
ADDED
@@ -0,0 +1,141 @@
+"""File utilities."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+from typing import List, Union
+
+import numpy as np
+
+from careamics.config.support import SupportedData
+from careamics.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def get_files_size(files: List[Path]) -> float:
+    """Get files size in MB.
+
+    Parameters
+    ----------
+    files : List[Path]
+        List of files.
+
+    Returns
+    -------
+    float
+        Total size of the files in MB.
+    """
+    return np.sum([f.stat().st_size / 1024**2 for f in files])
+
+
+def list_files(
+    data_path: Union[str, Path],
+    data_type: Union[str, SupportedData],
+    extension_filter: str = "",
+) -> List[Path]:
+    """List recursively files in `data_path` and return a sorted list.
+
+    If `data_path` is a file, its name is validated against the `data_type` using
+    `fnmatch`, and the method returns `data_path` itself.
+
+    By default, if `data_type` is equal to `custom`, all files will be listed. To
+    further filter the files, use `extension_filter`.
+
+    `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
+    or "*.czi".
+
+    Parameters
+    ----------
+    data_path : Union[str, Path]
+        Path to the folder containing the data.
+    data_type : Union[str, SupportedData]
+        One of the supported data type (e.g. tif, custom).
+    extension_filter : str, optional
+        Extension filter, by default "".
+
+    Returns
+    -------
+    List[Path]
+        List of pathlib.Path objects.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the data path does not exist.
+    ValueError
+        If the data path is empty or no files with the extension were found.
+    ValueError
+        If the file does not match the requested extension.
+    """
+    # convert to Path
+    data_path = Path(data_path)
+
+    # raise error if does not exists
+    if not data_path.exists():
+        raise FileNotFoundError(f"Data path {data_path} does not exist.")
+
+    # get extension compatible with fnmatch and rglob search
+    extension = SupportedData.get_extension_pattern(data_type)
+
+    if data_type == SupportedData.CUSTOM and extension_filter != "":
+        extension = extension_filter
+
+    # search recurively
+    if data_path.is_dir():
+        # search recursively the path for files with the extension
+        files = sorted(data_path.rglob(extension))
+    else:
+        # raise error if it has the wrong extension
+        if not fnmatch(str(data_path.absolute()), extension):
+            raise ValueError(
+                f"File {data_path} does not match the requested extension "
+                f'"{extension}".'
+            )
+
+        # save in list
+        files = [data_path]
+
+    # raise error if no files were found
+    if len(files) == 0:
+        raise ValueError(
+            f'Data path {data_path} is empty or files with extension "{extension}" '
+            f"were not found."
+        )
+
+    return files
+
+
+def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
+    """
+    Validate source and target path lists.
+
+    The two lists should have the same number of files, and the filenames should match.
+
+    Parameters
+    ----------
+    src_files : List[Path]
+        List of source files.
+    tar_files : List[Path]
+        List of target files.
+
+    Raises
+    ------
+    ValueError
+        If the number of files in source and target folders is not the same.
+    ValueError
+        If some filenames in Train and target folders are not the same.
+    """
+    # check equal length
+    if len(src_files) != len(tar_files):
+        raise ValueError(
+            f"The number of source files ({len(src_files)}) is not equal to the number "
+            f"of target files ({len(tar_files)})."
+        )
+
+    # check identical names
+    src_names = {f.name for f in src_files}
+    tar_names = {f.name for f in tar_files}
+    difference = src_names.symmetric_difference(tar_names)
+
+    if len(difference) > 0:
+        raise ValueError(f"Source and target files have different names: {difference}.")
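As an illustration (not part of the diff), `list_files` accepts either a folder or a single file, and `validate_source_target_files` is the natural follow-up when paired folders are used. The paths are hypothetical, and the `"tif"` data-type string follows the docstring example; the exact accepted values live in `SupportedData`.

    # Illustrative sketch, not part of the diff; paths and data-type string are assumptions
    from careamics.dataset.dataset_utils import list_files, validate_source_target_files

    train_files = list_files("data/noisy", "tif")
    target_files = list_files("data/clean", "tif")

    # raises ValueError if the counts or the filenames differ between the two folders
    validate_source_target_files(train_files, target_files)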
careamics/dataset/dataset_utils/iterate_over_files.py
ADDED
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, Optional, Union
+
+from numpy.typing import NDArray
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.file_io.read import read_tiff
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: list[Path],
+    target_files: Optional[list[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : CAREamics DataConfig or InferenceConfig
+        Configuration.
+    data_files : list of pathlib.Path
+        List of data files.
+    target_files : list of pathlib.Path, optional
+        List of target files, by default None.
+    read_source_func : Callable, optional
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    NDArray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
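Finally, an illustrative sketch (not part of the diff) of driving this generator: `iterate_over_files` only reads `data_config.axes` in the code path above, so a simple stand-in object is used here in place of a full `DataConfig`/`InferenceConfig`; the file paths are hypothetical.

    # Illustrative sketch, not part of the diff; the stand-in config and paths are assumptions
    from types import SimpleNamespace

    from careamics.dataset.dataset_utils import iterate_over_files, list_files

    config = SimpleNamespace(axes="YX")   # stand-in exposing only the attribute used above
    files = list_files("data/noisy", "tif")

    for sample, target in iterate_over_files(config, files):
        print(sample.shape)   # reshaped to (S, C, Y, X); target is None without target_files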