careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from fnmatch import fnmatch
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
@@ -12,12 +12,12 @@ from careamics.utils.logging import get_logger
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
def get_files_size(files:
|
|
15
|
+
def get_files_size(files: list[Path]) -> float:
|
|
16
16
|
"""Get files size in MB.
|
|
17
17
|
|
|
18
18
|
Parameters
|
|
19
19
|
----------
|
|
20
|
-
files :
|
|
20
|
+
files : list of pathlib.Path
|
|
21
21
|
List of files.
|
|
22
22
|
|
|
23
23
|
Returns
|
|
@@ -32,7 +32,7 @@ def list_files(
|
|
|
32
32
|
data_path: Union[str, Path],
|
|
33
33
|
data_type: Union[str, SupportedData],
|
|
34
34
|
extension_filter: str = "",
|
|
35
|
-
) ->
|
|
35
|
+
) -> list[Path]:
|
|
36
36
|
"""List recursively files in `data_path` and return a sorted list.
|
|
37
37
|
|
|
38
38
|
If `data_path` is a file, its name is validated against the `data_type` using
|
|
@@ -55,8 +55,8 @@ def list_files(
|
|
|
55
55
|
|
|
56
56
|
Returns
|
|
57
57
|
-------
|
|
58
|
-
|
|
59
|
-
|
|
58
|
+
list[Path]
|
|
59
|
+
list of pathlib.Path objects.
|
|
60
60
|
|
|
61
61
|
Raises
|
|
62
62
|
------
|
|
@@ -105,7 +105,7 @@ def list_files(
|
|
|
105
105
|
return files
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def validate_source_target_files(src_files:
|
|
108
|
+
def validate_source_target_files(src_files: list[Path], tar_files: list[Path]) -> None:
|
|
109
109
|
"""
|
|
110
110
|
Validate source and target path lists.
|
|
111
111
|
|
|
@@ -113,9 +113,9 @@ def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -
|
|
|
113
113
|
|
|
114
114
|
Parameters
|
|
115
115
|
----------
|
|
116
|
-
src_files :
|
|
116
|
+
src_files : list of pathlib.Path
|
|
117
117
|
List of source files.
|
|
118
|
-
tar_files :
|
|
118
|
+
tar_files : list of pathlib.Path
|
|
119
119
|
List of target files.
|
|
120
120
|
|
|
121
121
|
Raises
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Generator
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Callable,
|
|
7
|
+
from typing import Callable, Optional, Union
|
|
7
8
|
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
from torch.utils.data import get_worker_info
|
|
10
11
|
|
|
11
|
-
from careamics.config import
|
|
12
|
+
from careamics.config import GeneralDataConfig, InferenceConfig
|
|
12
13
|
from careamics.file_io.read import read_tiff
|
|
13
14
|
from careamics.utils.logging import get_logger
|
|
14
15
|
|
|
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
|
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
def iterate_over_files(
|
|
21
|
-
data_config: Union[
|
|
22
|
+
data_config: Union[GeneralDataConfig, InferenceConfig],
|
|
22
23
|
data_files: list[Path],
|
|
23
24
|
target_files: Optional[list[Path]] = None,
|
|
24
25
|
read_source_func: Callable = read_tiff,
|
|
@@ -34,36 +34,35 @@ def update_iterative_stats(
|
|
|
34
34
|
Parameters
|
|
35
35
|
----------
|
|
36
36
|
count : NDArray
|
|
37
|
-
Number of elements in the array.
|
|
37
|
+
Number of elements in the array. Shape: (C,).
|
|
38
38
|
mean : NDArray
|
|
39
|
-
Mean of the array.
|
|
39
|
+
Mean of the array. Shape: (C,).
|
|
40
40
|
m2 : NDArray
|
|
41
|
-
Variance of the array.
|
|
41
|
+
Variance of the array. Shape: (C,).
|
|
42
42
|
new_values : NDArray
|
|
43
|
-
New values to add to the mean and variance.
|
|
43
|
+
New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X).
|
|
44
44
|
|
|
45
45
|
Returns
|
|
46
46
|
-------
|
|
47
47
|
tuple[NDArray, NDArray, NDArray]
|
|
48
48
|
Updated count, mean, and variance.
|
|
49
49
|
"""
|
|
50
|
-
|
|
51
|
-
# newvalues - oldMean
|
|
52
|
-
delta = [
|
|
53
|
-
np.subtract(v.flatten(), [m] * len(v.flatten()))
|
|
54
|
-
for v, m in zip(new_values, mean)
|
|
55
|
-
]
|
|
50
|
+
num_channels = len(new_values)
|
|
56
51
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
delta2 = [
|
|
60
|
-
np.subtract(v.flatten(), [m] * len(v.flatten()))
|
|
61
|
-
for v, m in zip(new_values, mean)
|
|
62
|
-
]
|
|
52
|
+
# --- update channel-wise counts ---
|
|
53
|
+
count += np.ones_like(count) * np.prod(new_values.shape[1:])
|
|
63
54
|
|
|
64
|
-
|
|
55
|
+
# --- update channel-wise mean ---
|
|
56
|
+
# compute (new_values - old_mean) -> shape: (C, Z*Y*X)
|
|
57
|
+
delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
|
|
58
|
+
mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)
|
|
65
59
|
|
|
66
|
-
|
|
60
|
+
# --- update channel-wise SoS ---
|
|
61
|
+
# compute (new_values - new_mean) -> shape: (C, Z*Y*X)
|
|
62
|
+
delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
|
|
63
|
+
m2 += np.sum(delta * delta2, axis=1)
|
|
64
|
+
|
|
65
|
+
return count, mean, m2
|
|
67
66
|
|
|
68
67
|
|
|
69
68
|
def finalize_iterative_stats(
|
|
@@ -74,18 +73,18 @@ def finalize_iterative_stats(
|
|
|
74
73
|
Parameters
|
|
75
74
|
----------
|
|
76
75
|
count : NDArray
|
|
77
|
-
Number of elements in the array.
|
|
76
|
+
Number of elements in the array. Shape: (C,).
|
|
78
77
|
mean : NDArray
|
|
79
|
-
Mean of the array.
|
|
78
|
+
Mean of the array. Shape: (C,).
|
|
80
79
|
m2 : NDArray
|
|
81
|
-
Variance of the array.
|
|
80
|
+
Variance of the array. Shape: (C,).
|
|
82
81
|
|
|
83
82
|
Returns
|
|
84
83
|
-------
|
|
85
84
|
tuple[NDArray, NDArray]
|
|
86
|
-
Final mean and standard deviation.
|
|
85
|
+
Final channel-wise mean and standard deviation.
|
|
87
86
|
"""
|
|
88
|
-
std = np.
|
|
87
|
+
std = np.sqrt(m2 / count)
|
|
89
88
|
if any(c < 2 for c in count):
|
|
90
89
|
return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
|
|
91
90
|
else:
|
|
@@ -9,13 +9,9 @@ from typing import Any, Callable, Optional, Union
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from torch.utils.data import Dataset
|
|
11
11
|
|
|
12
|
-
from careamics.
|
|
13
|
-
from careamics.
|
|
14
|
-
|
|
15
|
-
from ..config import DataConfig
|
|
16
|
-
from ..config.transformations import NormalizeModel
|
|
17
|
-
from ..utils.logging import get_logger
|
|
18
|
-
from .patching.patching import (
|
|
12
|
+
from careamics.config import GeneralDataConfig, N2VDataConfig
|
|
13
|
+
from careamics.config.transformations import NormalizeModel
|
|
14
|
+
from careamics.dataset.patching.patching import (
|
|
19
15
|
PatchedOutput,
|
|
20
16
|
Stats,
|
|
21
17
|
prepare_patches_supervised,
|
|
@@ -23,6 +19,9 @@ from .patching.patching import (
|
|
|
23
19
|
prepare_patches_unsupervised,
|
|
24
20
|
prepare_patches_unsupervised_array,
|
|
25
21
|
)
|
|
22
|
+
from careamics.file_io.read import read_tiff
|
|
23
|
+
from careamics.transforms import Compose
|
|
24
|
+
from careamics.utils.logging import get_logger
|
|
26
25
|
|
|
27
26
|
logger = get_logger(__name__)
|
|
28
27
|
|
|
@@ -47,7 +46,7 @@ class InMemoryDataset(Dataset):
|
|
|
47
46
|
|
|
48
47
|
def __init__(
|
|
49
48
|
self,
|
|
50
|
-
data_config:
|
|
49
|
+
data_config: GeneralDataConfig,
|
|
51
50
|
inputs: Union[np.ndarray, list[Path]],
|
|
52
51
|
input_target: Optional[Union[np.ndarray, list[Path]]] = None,
|
|
53
52
|
read_source_func: Callable = read_tiff,
|
|
@@ -58,7 +57,7 @@ class InMemoryDataset(Dataset):
|
|
|
58
57
|
|
|
59
58
|
Parameters
|
|
60
59
|
----------
|
|
61
|
-
data_config :
|
|
60
|
+
data_config : GeneralDataConfig
|
|
62
61
|
Data configuration.
|
|
63
62
|
inputs : numpy.ndarray or list[pathlib.Path]
|
|
64
63
|
Input data.
|
|
@@ -124,7 +123,7 @@ class InMemoryDataset(Dataset):
|
|
|
124
123
|
target_stds=self.target_stats.stds,
|
|
125
124
|
)
|
|
126
125
|
]
|
|
127
|
-
+ self.data_config.transforms,
|
|
126
|
+
+ list(self.data_config.transforms),
|
|
128
127
|
)
|
|
129
128
|
|
|
130
129
|
def _prepare_patches(self, supervised: bool) -> PatchedOutput:
|
|
@@ -219,12 +218,12 @@ class InMemoryDataset(Dataset):
|
|
|
219
218
|
|
|
220
219
|
return self.patch_transform(patch=patch, target=target)
|
|
221
220
|
|
|
222
|
-
elif self.data_config
|
|
221
|
+
elif isinstance(self.data_config, N2VDataConfig):
|
|
223
222
|
return self.patch_transform(patch=patch)
|
|
224
223
|
else:
|
|
225
224
|
raise ValueError(
|
|
226
225
|
"Something went wrong! No target provided (not supervised training) "
|
|
227
|
-
"
|
|
226
|
+
"while the algorithm is not Noise2Void."
|
|
228
227
|
)
|
|
229
228
|
|
|
230
229
|
def get_data_statistics(self) -> tuple[list[float], list[float]]:
|
|
@@ -10,7 +10,7 @@ from typing import Callable, Optional
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from torch.utils.data import IterableDataset
|
|
12
12
|
|
|
13
|
-
from careamics.config import
|
|
13
|
+
from careamics.config import GeneralDataConfig
|
|
14
14
|
from careamics.config.transformations import NormalizeModel
|
|
15
15
|
from careamics.file_io.read import read_tiff
|
|
16
16
|
from careamics.transforms import Compose
|
|
@@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
49
49
|
|
|
50
50
|
def __init__(
|
|
51
51
|
self,
|
|
52
|
-
data_config:
|
|
52
|
+
data_config: GeneralDataConfig,
|
|
53
53
|
src_files: list[Path],
|
|
54
54
|
target_files: Optional[list[Path]] = None,
|
|
55
55
|
read_source_func: Callable = read_tiff,
|
|
@@ -58,7 +58,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
58
58
|
|
|
59
59
|
Parameters
|
|
60
60
|
----------
|
|
61
|
-
data_config :
|
|
61
|
+
data_config : GeneralDataConfig
|
|
62
62
|
Data configuration.
|
|
63
63
|
src_files : list[Path]
|
|
64
64
|
List of data files.
|
|
@@ -115,7 +115,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
115
115
|
target_stds=self.target_stats.stds,
|
|
116
116
|
)
|
|
117
117
|
]
|
|
118
|
-
+ data_config.transforms
|
|
118
|
+
+ list(data_config.transforms)
|
|
119
119
|
)
|
|
120
120
|
|
|
121
121
|
def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
|
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Generator
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable
|
|
7
|
+
from typing import Any, Callable
|
|
7
8
|
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
from torch.utils.data import IterableDataset
|
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Generator
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable
|
|
7
|
+
from typing import Any, Callable
|
|
7
8
|
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
from torch.utils.data import IterableDataset
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Random patching utilities."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from typing import Optional, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import zarr
|
|
@@ -11,10 +12,10 @@ from .validate_patch_dimension import validate_patch_dimensions
|
|
|
11
12
|
# TOOD split in testable functions
|
|
12
13
|
def extract_patches_random(
|
|
13
14
|
arr: np.ndarray,
|
|
14
|
-
patch_size: Union[
|
|
15
|
+
patch_size: Union[list[int], tuple[int, ...]],
|
|
15
16
|
target: Optional[np.ndarray] = None,
|
|
16
17
|
seed: Optional[int] = None,
|
|
17
|
-
) -> Generator[
|
|
18
|
+
) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
18
19
|
"""
|
|
19
20
|
Generate patches from an array in a random manner.
|
|
20
21
|
|
|
@@ -31,12 +32,12 @@ def extract_patches_random(
|
|
|
31
32
|
----------
|
|
32
33
|
arr : np.ndarray
|
|
33
34
|
Input image array.
|
|
34
|
-
patch_size :
|
|
35
|
+
patch_size : tuple of int
|
|
35
36
|
Patch sizes in each dimension.
|
|
36
37
|
target : Optional[np.ndarray], optional
|
|
37
38
|
Target array, by default None.
|
|
38
|
-
seed :
|
|
39
|
-
Random seed
|
|
39
|
+
seed : int or None, default=None
|
|
40
|
+
Random seed.
|
|
40
41
|
|
|
41
42
|
Yields
|
|
42
43
|
------
|
|
@@ -112,8 +113,8 @@ def extract_patches_random(
|
|
|
112
113
|
|
|
113
114
|
def extract_patches_random_from_chunks(
|
|
114
115
|
arr: zarr.Array,
|
|
115
|
-
patch_size: Union[
|
|
116
|
-
chunk_size: Union[
|
|
116
|
+
patch_size: Union[list[int], tuple[int, ...]],
|
|
117
|
+
chunk_size: Union[list[int], tuple[int, ...]],
|
|
117
118
|
chunk_limit: Optional[int] = None,
|
|
118
119
|
seed: Optional[int] = None,
|
|
119
120
|
) -> Generator[np.ndarray, None, None]:
|
|
@@ -127,9 +128,9 @@ def extract_patches_random_from_chunks(
|
|
|
127
128
|
----------
|
|
128
129
|
arr : np.ndarray
|
|
129
130
|
Input image array.
|
|
130
|
-
patch_size : Union[
|
|
131
|
+
patch_size : Union[list[int], tuple[int, ...]]
|
|
131
132
|
Patch sizes in each dimension.
|
|
132
|
-
chunk_size : Union[
|
|
133
|
+
chunk_size : Union[list[int], tuple[int, ...]]
|
|
133
134
|
Chunk sizes to load from the.
|
|
134
135
|
chunk_limit : Optional[int], optional
|
|
135
136
|
Number of chunks to load, by default None.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Sequential patching functions."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Optional, Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from skimage.util import view_as_windows
|
|
@@ -9,21 +9,21 @@ from .validate_patch_dimension import validate_patch_dimensions
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def _compute_number_of_patches(
|
|
12
|
-
arr_shape:
|
|
13
|
-
) ->
|
|
12
|
+
arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
|
|
13
|
+
) -> tuple[int, ...]:
|
|
14
14
|
"""
|
|
15
15
|
Compute the number of patches that fit in each dimension.
|
|
16
16
|
|
|
17
17
|
Parameters
|
|
18
18
|
----------
|
|
19
|
-
arr_shape :
|
|
19
|
+
arr_shape : tuple[int, ...]
|
|
20
20
|
Shape of the input array.
|
|
21
|
-
patch_sizes : Union[
|
|
21
|
+
patch_sizes : Union[list[int], tuple[int, ...]
|
|
22
22
|
Shape of the patches.
|
|
23
23
|
|
|
24
24
|
Returns
|
|
25
25
|
-------
|
|
26
|
-
|
|
26
|
+
tuple[int, ...]
|
|
27
27
|
Number of patches in each dimension.
|
|
28
28
|
"""
|
|
29
29
|
if len(arr_shape) != len(patch_sizes):
|
|
@@ -47,8 +47,8 @@ def _compute_number_of_patches(
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def _compute_overlap(
|
|
50
|
-
arr_shape:
|
|
51
|
-
) ->
|
|
50
|
+
arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
|
|
51
|
+
) -> tuple[int, ...]:
|
|
52
52
|
"""
|
|
53
53
|
Compute the overlap between patches in each dimension.
|
|
54
54
|
|
|
@@ -57,14 +57,14 @@ def _compute_overlap(
|
|
|
57
57
|
|
|
58
58
|
Parameters
|
|
59
59
|
----------
|
|
60
|
-
arr_shape :
|
|
60
|
+
arr_shape : tuple[int, ...]
|
|
61
61
|
Input array shape.
|
|
62
|
-
patch_sizes : Union[
|
|
62
|
+
patch_sizes : Union[list[int], tuple[int, ...]]
|
|
63
63
|
Size of the patches.
|
|
64
64
|
|
|
65
65
|
Returns
|
|
66
66
|
-------
|
|
67
|
-
|
|
67
|
+
tuple[int, ...]
|
|
68
68
|
Overlap between patches in each dimension.
|
|
69
69
|
"""
|
|
70
70
|
n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
|
|
@@ -80,21 +80,21 @@ def _compute_overlap(
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def _compute_patch_steps(
|
|
83
|
-
patch_sizes: Union[
|
|
84
|
-
) ->
|
|
83
|
+
patch_sizes: Union[list[int], tuple[int, ...]], overlaps: tuple[int, ...]
|
|
84
|
+
) -> tuple[int, ...]:
|
|
85
85
|
"""
|
|
86
86
|
Compute steps between patches.
|
|
87
87
|
|
|
88
88
|
Parameters
|
|
89
89
|
----------
|
|
90
|
-
patch_sizes :
|
|
90
|
+
patch_sizes : tuple[int]
|
|
91
91
|
Size of the patches.
|
|
92
|
-
overlaps :
|
|
92
|
+
overlaps : tuple[int]
|
|
93
93
|
Overlap between patches.
|
|
94
94
|
|
|
95
95
|
Returns
|
|
96
96
|
-------
|
|
97
|
-
|
|
97
|
+
tuple[int]
|
|
98
98
|
Steps between patches.
|
|
99
99
|
"""
|
|
100
100
|
steps = [
|
|
@@ -107,9 +107,9 @@ def _compute_patch_steps(
|
|
|
107
107
|
# TODO why stack the target here and not on a different dimension before this function?
|
|
108
108
|
def _compute_patch_views(
|
|
109
109
|
arr: np.ndarray,
|
|
110
|
-
window_shape:
|
|
111
|
-
step:
|
|
112
|
-
output_shape:
|
|
110
|
+
window_shape: list[int],
|
|
111
|
+
step: tuple[int, ...],
|
|
112
|
+
output_shape: list[int],
|
|
113
113
|
target: Optional[np.ndarray] = None,
|
|
114
114
|
) -> np.ndarray:
|
|
115
115
|
"""
|
|
@@ -119,11 +119,11 @@ def _compute_patch_views(
|
|
|
119
119
|
----------
|
|
120
120
|
arr : np.ndarray
|
|
121
121
|
Array from which the views are extracted.
|
|
122
|
-
window_shape :
|
|
122
|
+
window_shape : tuple[int]
|
|
123
123
|
Shape of the views.
|
|
124
|
-
step :
|
|
124
|
+
step : tuple[int]
|
|
125
125
|
Steps between views.
|
|
126
|
-
output_shape :
|
|
126
|
+
output_shape : tuple[int]
|
|
127
127
|
Shape of the output array.
|
|
128
128
|
target : Optional[np.ndarray], optional
|
|
129
129
|
Target array, by default None.
|
|
@@ -150,9 +150,9 @@ def _compute_patch_views(
|
|
|
150
150
|
|
|
151
151
|
def extract_patches_sequential(
|
|
152
152
|
arr: np.ndarray,
|
|
153
|
-
patch_size: Union[
|
|
153
|
+
patch_size: Union[list[int], tuple[int, ...]],
|
|
154
154
|
target: Optional[np.ndarray] = None,
|
|
155
|
-
) ->
|
|
155
|
+
) -> tuple[np.ndarray, Optional[np.ndarray]]:
|
|
156
156
|
"""
|
|
157
157
|
Generate patches from an array in a sequential manner.
|
|
158
158
|
|
|
@@ -163,14 +163,14 @@ def extract_patches_sequential(
|
|
|
163
163
|
----------
|
|
164
164
|
arr : np.ndarray
|
|
165
165
|
Input image array.
|
|
166
|
-
patch_size :
|
|
166
|
+
patch_size : tuple[int]
|
|
167
167
|
Patch sizes in each dimension.
|
|
168
168
|
target : Optional[np.ndarray], optional
|
|
169
169
|
Target array, by default None.
|
|
170
170
|
|
|
171
171
|
Returns
|
|
172
172
|
-------
|
|
173
|
-
|
|
173
|
+
tuple[np.ndarray, Optional[np.ndarray]]
|
|
174
174
|
Patches.
|
|
175
175
|
"""
|
|
176
176
|
is_3d_patch = len(patch_size) == 3
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Patch validation functions."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def validate_patch_dimensions(
|
|
9
9
|
arr: np.ndarray,
|
|
10
|
-
patch_size: Union[
|
|
10
|
+
patch_size: Union[list[int], tuple[int, ...]],
|
|
11
11
|
is_3d_patch: bool,
|
|
12
12
|
) -> None:
|
|
13
13
|
"""
|
|
@@ -26,7 +26,7 @@ def validate_patch_dimensions(
|
|
|
26
26
|
----------
|
|
27
27
|
arr : np.ndarray
|
|
28
28
|
Input array.
|
|
29
|
-
patch_size : Union[
|
|
29
|
+
patch_size : Union[list[int], tuple[int, ...]]
|
|
30
30
|
Size of the patches along each dimension of the array, except the first.
|
|
31
31
|
is_3d_patch : bool
|
|
32
32
|
Whether the patch is 3D or not.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Collate function for tiling."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from torch.utils.data.dataloader import default_collate
|
|
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
|
|
|
8
8
|
from careamics.config.tile_information import TileInformation
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
def collate_tiles(batch:
|
|
11
|
+
def collate_tiles(batch: list[tuple[np.ndarray, TileInformation]]) -> Any:
|
|
12
12
|
"""
|
|
13
13
|
Collate tiles received from CAREamics prediction dataloader.
|
|
14
14
|
|
|
@@ -19,7 +19,7 @@ def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
|
19
19
|
|
|
20
20
|
Parameters
|
|
21
21
|
----------
|
|
22
|
-
batch :
|
|
22
|
+
batch : list[tuple[np.ndarray, TileInformation], ...]
|
|
23
23
|
Batch of tiles.
|
|
24
24
|
|
|
25
25
|
Returns
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Tiled patching utilities."""
|
|
2
2
|
|
|
3
3
|
import itertools
|
|
4
|
-
from
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from typing import Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
@@ -10,7 +11,7 @@ from careamics.config.tile_information import TileInformation
|
|
|
10
11
|
|
|
11
12
|
def _compute_crop_and_stitch_coords_1d(
|
|
12
13
|
axis_size: int, tile_size: int, overlap: int
|
|
13
|
-
) ->
|
|
14
|
+
) -> tuple[list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]]]:
|
|
14
15
|
"""
|
|
15
16
|
Compute the coordinates of each tile along an axis, given the overlap.
|
|
16
17
|
|
|
@@ -25,8 +26,8 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
25
26
|
|
|
26
27
|
Returns
|
|
27
28
|
-------
|
|
28
|
-
|
|
29
|
-
|
|
29
|
+
tuple[tuple[int, ...], ...]
|
|
30
|
+
tuple of all coordinates for given axis.
|
|
30
31
|
"""
|
|
31
32
|
# Compute the step between tiles
|
|
32
33
|
step = tile_size - overlap
|
|
@@ -81,9 +82,9 @@ def _compute_crop_and_stitch_coords_1d(
|
|
|
81
82
|
|
|
82
83
|
def extract_tiles(
|
|
83
84
|
arr: np.ndarray,
|
|
84
|
-
tile_size: Union[
|
|
85
|
-
overlaps: Union[
|
|
86
|
-
) -> Generator[
|
|
85
|
+
tile_size: Union[list[int], tuple[int, ...]],
|
|
86
|
+
overlaps: Union[list[int], tuple[int, ...]],
|
|
87
|
+
) -> Generator[tuple[np.ndarray, TileInformation], None, None]:
|
|
87
88
|
"""Generate tiles from the input array with specified overlap.
|
|
88
89
|
|
|
89
90
|
The tiles cover the whole array. The method returns a generator that yields
|
|
@@ -98,14 +99,14 @@ def extract_tiles(
|
|
|
98
99
|
----------
|
|
99
100
|
arr : np.ndarray
|
|
100
101
|
Array of shape (S, C, (Z), Y, X).
|
|
101
|
-
tile_size : Union[
|
|
102
|
+
tile_size : Union[list[int], tuple[int]]
|
|
102
103
|
Tile sizes in each dimension, of length 2 or 3.
|
|
103
|
-
overlaps : Union[
|
|
104
|
+
overlaps : Union[list[int], tuple[int]]
|
|
104
105
|
Overlap values in each dimension, of length 2 or 3.
|
|
105
106
|
|
|
106
107
|
Yields
|
|
107
108
|
------
|
|
108
|
-
Generator[
|
|
109
|
+
Generator[tuple[np.ndarray, TileInformation], None, None]
|
|
109
110
|
Tile generator, yields the tile and additional information.
|
|
110
111
|
"""
|
|
111
112
|
# Iterate over num samples (S)
|
careamics/file_io/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Functions relating reading and writing image files."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"read",
|
|
5
|
-
"write",
|
|
6
|
-
"get_read_func",
|
|
7
|
-
"get_write_func",
|
|
8
4
|
"ReadFunc",
|
|
9
|
-
"WriteFunc",
|
|
10
5
|
"SupportedWriteType",
|
|
6
|
+
"WriteFunc",
|
|
7
|
+
"get_read_func",
|
|
8
|
+
"get_write_func",
|
|
9
|
+
"read",
|
|
10
|
+
"write",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
from . import read, write
|