careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Patch validation functions."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes are smaller than or equal to the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, whose first two axes are S and C, followed by the spatial
        axes ((Z), Y, X).
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. every
        dimension except the first two (S and C).
    is_3d_patch : bool
        Whether the patch is 3D or not.

    Raises
    ------
    ValueError
        If the number of patch sizes differs from the number of spatial dimensions
        of the array (all dimensions but the first two).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
        )

    # Sanity checks on patch sizes versus array dimension
    # The Z axis is the third-to-last axis; report that same dimension in the
    # error (previously the C axis, arr.shape[1], was reported by mistake).
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]}). Check the axes "
            f"order."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
            f"Check the axes order."
        )
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Collate function for tiling."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data.dataloader import default_collate
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
    """
    Collate tiles received from the CAREamics prediction dataloader.

    Every item of the batch is a ``(tile, tile_information)`` pair. The tiles are
    collated together with PyTorch's ``default_collate``, while the tile
    information objects are gathered, in batch order, into a plain list.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of (tile, tile information) pairs.

    Returns
    -------
    Any
        Two-element tuple: the collated tiles and the list of TileInformation.
    """
    arrays = [array for array, _info in batch]
    infos = [info for _array, info in batch]

    collated_arrays = default_collate(arrays)
    return collated_arrays, infos
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Tiled patching utilities."""
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from typing import Generator, List, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _compute_crop_and_stitch_coords_1d(
|
|
12
|
+
axis_size: int, tile_size: int, overlap: int
|
|
13
|
+
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
|
|
14
|
+
"""
|
|
15
|
+
Compute the coordinates of each tile along an axis, given the overlap.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
axis_size : int
|
|
20
|
+
Length of the axis.
|
|
21
|
+
tile_size : int
|
|
22
|
+
Size of the tile for the given axis.
|
|
23
|
+
overlap : int
|
|
24
|
+
Size of the overlap for the given axis.
|
|
25
|
+
|
|
26
|
+
Returns
|
|
27
|
+
-------
|
|
28
|
+
Tuple[Tuple[int, ...], ...]
|
|
29
|
+
Tuple of all coordinates for given axis.
|
|
30
|
+
"""
|
|
31
|
+
# Compute the step between tiles
|
|
32
|
+
step = tile_size - overlap
|
|
33
|
+
crop_coords = []
|
|
34
|
+
stitch_coords = []
|
|
35
|
+
overlap_crop_coords = []
|
|
36
|
+
|
|
37
|
+
# Iterate over the axis with step
|
|
38
|
+
for i in range(0, max(1, axis_size - overlap), step):
|
|
39
|
+
# Check if the tile fits within the axis
|
|
40
|
+
if i + tile_size <= axis_size:
|
|
41
|
+
# Add the coordinates to crop one tile
|
|
42
|
+
crop_coords.append((i, i + tile_size))
|
|
43
|
+
|
|
44
|
+
# Add the pixel coordinates of the cropped tile in the original image space
|
|
45
|
+
stitch_coords.append(
|
|
46
|
+
(
|
|
47
|
+
i + overlap // 2 if i > 0 else 0,
|
|
48
|
+
(
|
|
49
|
+
i + tile_size - overlap // 2
|
|
50
|
+
if crop_coords[-1][1] < axis_size
|
|
51
|
+
else axis_size
|
|
52
|
+
),
|
|
53
|
+
)
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# Add the coordinates to crop the overlap from the prediction.
|
|
57
|
+
overlap_crop_coords.append(
|
|
58
|
+
(
|
|
59
|
+
overlap // 2 if i > 0 else 0,
|
|
60
|
+
(
|
|
61
|
+
tile_size - overlap // 2
|
|
62
|
+
if crop_coords[-1][1] < axis_size
|
|
63
|
+
else tile_size
|
|
64
|
+
),
|
|
65
|
+
)
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# If the tile does not fit within the axis, perform the abovementioned
|
|
69
|
+
# operations starting from the end of the axis
|
|
70
|
+
else:
|
|
71
|
+
# if (axis_size - tile_size, axis_size) not in crop_coords:
|
|
72
|
+
crop_coords.append((max(0, axis_size - tile_size), axis_size))
|
|
73
|
+
last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1
|
|
74
|
+
stitch_coords.append((last_tile_end_coord, axis_size))
|
|
75
|
+
overlap_crop_coords.append(
|
|
76
|
+
(tile_size - (axis_size - last_tile_end_coord), tile_size)
|
|
77
|
+
)
|
|
78
|
+
break
|
|
79
|
+
return crop_coords, stitch_coords, overlap_crop_coords
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """Generate overlapping tiles covering the whole input array.

    Samples are processed one after the other. For each sample, the tiles are
    yielded together with a `TileInformation` describing the sample shape, whether
    the tile is the last one of the sample, the coordinates used to crop the
    overlap out of the tile, and the coordinates where the cropped tile is
    stitched back.

    The input array is expected in SC(Z)YX order; each yielded tile is C(Z)YX,
    C possibly being a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each spatial dimension, of length 2 or 3.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile and its associated tile information.
    """
    n_spatial_axes = len(tile_size)

    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]

        # One (crop, stitch, overlap crop) triplet of coordinate lists per spatial
        # axis; spatial axes start at index 1 of the sample (index 0 is C).
        # E.g. an axis of size 35 with tiles of 32 and an overlap of 24 gives
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        per_axis_coords = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[axis + 1], tile_size[axis], overlaps[axis]
            )
            for axis in range(n_spatial_axes)
        ]

        # Regroup by coordinate type: all crops, all stitches, all overlap crops
        crops_per_axis, stitches_per_axis, overlap_crops_per_axis = zip(
            *per_axis_coords
        )

        # Total number of tiles is the product of the tile counts along each axis
        n_tiles = int(np.prod([len(axis_crops) for axis_crops in crops_per_axis]))

        tile_coordinates = zip(
            itertools.product(*crops_per_axis),
            itertools.product(*stitches_per_axis),
            itertools.product(*overlap_crops_per_axis),
        )
        for tile_idx, (crop, stitch, overlap_crop) in enumerate(tile_coordinates):
            spatial_slices = tuple(slice(start, end) for start, end in crop)
            tile: np.ndarray = sample[(..., *spatial_slices)]  # type: ignore

            info = TileInformation(
                array_shape=sample.shape,
                last_tile=tile_idx == n_tiles - 1,
                overlap_crop_coords=overlap_crop,
                stitch_coords=stitch,
                sample_id=sample_idx,
            )

            yield tile, info
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Zarr dataset."""
|
|
2
|
+
|
|
3
|
+
# from itertools import islice
|
|
4
|
+
# from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
# import numpy as np
|
|
7
|
+
# import torch
|
|
8
|
+
# import zarr
|
|
9
|
+
|
|
10
|
+
# from careamics.utils import RunningStats
|
|
11
|
+
# from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
# from ..utils import normalize
|
|
14
|
+
# from .dataset_utils import read_zarr
|
|
15
|
+
# from .patching.patching import (
|
|
16
|
+
# generate_patches_unsupervised,
|
|
17
|
+
# )
|
|
18
|
+
|
|
19
|
+
# logger = get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# class ZarrDataset(torch.utils.data.IterableDataset):
|
|
23
|
+
# """Dataset to extract patches from a zarr storage.
|
|
24
|
+
|
|
25
|
+
# Parameters
|
|
26
|
+
# ----------
|
|
27
|
+
# data_source : Union[zarr.Group, zarr.Array]
|
|
28
|
+
# Zarr storage.
|
|
29
|
+
# axes : str
|
|
30
|
+
# Description of axes in format STCZYX.
|
|
31
|
+
# patch_extraction_method : Union[ExtractionStrategies, None]
|
|
32
|
+
# Patch extraction strategy, as defined in extraction_strategy.
|
|
33
|
+
# patch_size : Optional[Union[List[int], Tuple[int]]], optional
|
|
34
|
+
# Size of the patches in each dimension, by default None.
|
|
35
|
+
# num_patches : Optional[int], optional
|
|
36
|
+
# Number of patches to extract, by default None.
|
|
37
|
+
# mean : Optional[float], optional
|
|
38
|
+
# Expected mean of the dataset, by default None.
|
|
39
|
+
# std : Optional[float], optional
|
|
40
|
+
# Expected standard deviation of the dataset, by default None.
|
|
41
|
+
# patch_transform : Optional[Callable], optional
|
|
42
|
+
# Patch transform callable, by default None.
|
|
43
|
+
# patch_transform_params : Optional[Dict], optional
|
|
44
|
+
# Patch transform parameters, by default None.
|
|
45
|
+
# running_stats_window_perc : float, optional
|
|
46
|
+
# Percentage of the dataset to use for calculating the initial mean and standard
|
|
47
|
+
# deviation, by default 0.01.
|
|
48
|
+
# mode : str, optional
|
|
49
|
+
# train/predict, controls running stats calculation.
|
|
50
|
+
# """
|
|
51
|
+
|
|
52
|
+
# def __init__(
|
|
53
|
+
# self,
|
|
54
|
+
# data_source: Union[zarr.Group, zarr.Array],
|
|
55
|
+
# axes: str,
|
|
56
|
+
# patch_extraction_method: Union[SupportedExtractionStrategy, None],
|
|
57
|
+
# patch_size: Optional[Union[List[int], Tuple[int]]] = None,
|
|
58
|
+
# num_patches: Optional[int] = None,
|
|
59
|
+
# mean: Optional[float] = None,
|
|
60
|
+
# std: Optional[float] = None,
|
|
61
|
+
# patch_transform: Optional[Callable] = None,
|
|
62
|
+
# patch_transform_params: Optional[Dict] = None,
|
|
63
|
+
# running_stats_window_perc: float = 0.01,
|
|
64
|
+
# mode: str = "train",
|
|
65
|
+
# ) -> None:
|
|
66
|
+
# self.data_source = data_source
|
|
67
|
+
# self.axes = axes
|
|
68
|
+
# self.patch_extraction_method = patch_extraction_method
|
|
69
|
+
# self.patch_size = patch_size
|
|
70
|
+
# self.num_patches = num_patches
|
|
71
|
+
# self.mean = mean
|
|
72
|
+
# self.std = std
|
|
73
|
+
# self.patch_transform = patch_transform
|
|
74
|
+
# self.patch_transform_params = patch_transform_params
|
|
75
|
+
# self.sample = read_zarr(self.data_source, self.axes)
|
|
76
|
+
# self.running_stats_window = int(
|
|
77
|
+
# np.prod(self.sample._cdata_shape) * running_stats_window_perc
|
|
78
|
+
# )
|
|
79
|
+
# self.mode = mode
|
|
80
|
+
# self.running_stats = RunningStats()
|
|
81
|
+
|
|
82
|
+
# self._calculate_initial_mean_std()
|
|
83
|
+
|
|
84
|
+
# def _calculate_initial_mean_std(self):
|
|
85
|
+
# """Calculate initial mean and std of the dataset."""
|
|
86
|
+
# if self.mean is None and self.std is None:
|
|
87
|
+
# idxs = np.random.randint(
|
|
88
|
+
# 0,
|
|
89
|
+
# np.prod(self.sample._cdata_shape),
|
|
90
|
+
# size=max(1, self.running_stats_window),
|
|
91
|
+
# )
|
|
92
|
+
# random_chunks = self.sample[idxs]
|
|
93
|
+
# self.running_stats.init(random_chunks.mean(), random_chunks.std())
|
|
94
|
+
|
|
95
|
+
# def _generate_patches(self):
|
|
96
|
+
# """Generate patches from the dataset and calculates running stats.
|
|
97
|
+
|
|
98
|
+
# Yields
|
|
99
|
+
# ------
|
|
100
|
+
# np.ndarray
|
|
101
|
+
# Patch.
|
|
102
|
+
# """
|
|
103
|
+
# patches = generate_patches_unsupervised(
|
|
104
|
+
# self.sample,
|
|
105
|
+
# self.patch_extraction_method,
|
|
106
|
+
# self.patch_size,
|
|
107
|
+
# )
|
|
108
|
+
|
|
109
|
+
# # num_patches = np.ceil(
|
|
110
|
+
# # np.prod(self.sample.chunks)
|
|
111
|
+
# # / (np.prod(self.patch_size) * self.running_stats_window)
|
|
112
|
+
# # ).astype(int)
|
|
113
|
+
|
|
114
|
+
# for idx, patch in enumerate(patches):
|
|
115
|
+
# if self.mode != "predict":
|
|
116
|
+
# self.running_stats.update(patch.mean())
|
|
117
|
+
# if isinstance(patch, tuple):
|
|
118
|
+
# normalized_patch = normalize(
|
|
119
|
+
# img=patch[0],
|
|
120
|
+
# mean=self.running_stats.avg_mean.value,
|
|
121
|
+
# std=self.running_stats.avg_std.value,
|
|
122
|
+
# )
|
|
123
|
+
# patch = (normalized_patch, *patch[1:])
|
|
124
|
+
# else:
|
|
125
|
+
# patch = normalize(
|
|
126
|
+
# img=patch,
|
|
127
|
+
# mean=self.running_stats.avg_mean.value,
|
|
128
|
+
# std=self.running_stats.avg_std.value,
|
|
129
|
+
# )
|
|
130
|
+
|
|
131
|
+
# if self.patch_transform is not None:
|
|
132
|
+
# assert self.patch_transform_params is not None
|
|
133
|
+
# patch = self.patch_transform(patch, **self.patch_transform_params)
|
|
134
|
+
# if self.num_patches is not None and idx >= self.num_patches:
|
|
135
|
+
# return
|
|
136
|
+
# else:
|
|
137
|
+
# yield patch
|
|
138
|
+
# self.mean = self.running_stats.avg_mean.value
|
|
139
|
+
# self.std = self.running_stats.avg_std.value
|
|
140
|
+
|
|
141
|
+
# def __iter__(self):
|
|
142
|
+
# """
|
|
143
|
+
# Iterate over data source and yield single patch.
|
|
144
|
+
|
|
145
|
+
# Yields
|
|
146
|
+
# ------
|
|
147
|
+
# np.ndarray
|
|
148
|
+
# """
|
|
149
|
+
# worker_info = torch.utils.data.get_worker_info()
|
|
150
|
+
# num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
151
|
+
# yield from islice(self._generate_patches(), 0, None, num_workers)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Functions relating to reading and writing image files."""
|
|
2
|
+
|
|
3
|
+
# Public API of the file I/O package: the `read` and `write` subpackages and the
# names re-exported from them by the imports below.
__all__ = [
    "read",
    "write",
    "get_read_func",
    "get_write_func",
    "ReadFunc",
    "WriteFunc",
    "SupportedWriteType",
]
|
|
12
|
+
|
|
13
|
+
from . import read, write
|
|
14
|
+
from .read import ReadFunc, get_read_func
|
|
15
|
+
from .write import SupportedWriteType, WriteFunc, get_write_func
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Functions relating to reading image files of different formats."""
|
|
2
|
+
|
|
3
|
+
# Public API of the read subpackage, re-exported from the `get_func`, `tiff`
# and `zarr` modules by the imports below.
__all__ = [
    "get_read_func",
    "read_tiff",
    "read_zarr",
    "ReadFunc",
]
|
|
9
|
+
|
|
10
|
+
from .get_func import ReadFunc, get_read_func
|
|
11
|
+
from .tiff import read_tiff
|
|
12
|
+
from .zarr import read_zarr
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Module to get read functions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Callable, Dict, Protocol, Union
|
|
5
|
+
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedData
|
|
9
|
+
|
|
10
|
+
from .tiff import read_tiff
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# This is very strict, function signature has to match including arg names
|
|
14
|
+
# See WriteFunc notes
|
|
15
|
+
class ReadFunc(Protocol):
    """
    Structural type used to annotate functions that read image files.

    Any callable whose signature matches `__call__` below, argument names
    included, is considered a `ReadFunc` by static type checkers.
    """

    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
        """
        Read the file at `file_path` and return its content as an array.

        Implementations must match this signature (ignoring `self`).

        Parameters
        ----------
        file_path : pathlib.Path
            Path of the file to read.
        *args
            Additional positional arguments.
        **kwargs
            Additional keyword arguments.
        """
        ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Registry mapping a supported data type to its read function. `get_read_func`
# looks data types up in this mapping; only TIFF has a registered reader here.
READ_FUNCS: Dict[SupportedData, ReadFunc] = {
    SupportedData.TIFF: read_tiff,
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
    """
    Return the function able to read files of the given data type.

    Parameters
    ----------
    data_type : Union[str, SupportedData]
        Data type, either as a `SupportedData` member or its string value.

    Returns
    -------
    callable
        Read function registered in `READ_FUNCS` for that data type.

    Raises
    ------
    NotImplementedError
        If no read function is registered for the data type.
    """
    if data_type not in READ_FUNCS:
        raise NotImplementedError(f"Data type '{data_type}' is not supported.")

    # Normalize to the enum member so the dict lookup is typed consistently
    # (mypy complains about a str key otherwise).
    supported_type = SupportedData(data_type)
    return READ_FUNCS[supported_type]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Functions to read tiff images."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from fnmatch import fnmatch
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import tifffile
|
|
9
|
+
|
|
10
|
+
from careamics.config.support import SupportedData
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
    """
    Read a tiff file and return a numpy array.

    Parameters
    ----------
    file_path : Path
        Path to a file. A string path is also accepted.
    *args : list
        Additional arguments, unused.
    **kwargs : dict
        Additional keyword arguments, unused.

    Returns
    -------
    np.ndarray
        Resulting array.

    Raises
    ------
    ValueError
        If the file extension does not match the tiff extension pattern.
    ValueError
        If tifffile fails to read the file.
    OSError
        If the file failed to open.
    """
    # Accept str paths as well; `.suffix` requires a Path.
    file_path = Path(file_path)

    if not fnmatch(
        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
    ):
        raise ValueError(f"File {file_path} is not a valid tiff.")

    try:
        return tifffile.imread(file_path)
    except (ValueError, OSError) as e:
        # Log through the module logger (not the root logger), then propagate
        # the original exception with its traceback: the file is not skipped.
        logger.exception("Exception in file %s: %s", file_path, e)
        raise
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Function to read zarr images."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from zarr import Group, core, hierarchy, storage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def read_zarr(
    zarr_source: Group, axes: str
) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
    """Return a pointer to the array held by a zarr source.

    A zarr group yields its entry at key 0, and a zarr array is returned as-is.
    Directory stores and object-dtype arrays are not supported.

    Parameters
    ----------
    zarr_source : Group
        Zarr storage (group or array).
    axes : str
        Axes of the data, one character per array dimension.

    Returns
    -------
    Union[core.Array, storage.DirectoryStore, hierarchy.Group]
        Pointer to the zarr array (no data is loaded in memory here).

    Raises
    ------
    NotImplementedError
        If the source is a DirectoryStore or an array of object dtype.
    ValueError
        If the source is not a zarr group or array.
    ValueError
        If the data dimensions are not 2, 3 or 4.
    ValueError
        If the axes length is not consistent with the data dimensions.
    """
    if isinstance(zarr_source, hierarchy.Group):
        # NOTE(review): presumably the first (full resolution) array of the
        # group, stored under key "0" — confirm the expected group layout.
        array = zarr_source[0]
    elif isinstance(zarr_source, storage.DirectoryStore):
        raise NotImplementedError("DirectoryStore not supported yet")
    elif isinstance(zarr_source, core.Array):
        # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
        if zarr_source.dtype == "O":
            raise NotImplementedError("Object type not supported yet")
        array = zarr_source
    else:
        raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")

    n_dims = len(array.shape)

    # sanity check on dimensions
    if not 2 <= n_dims <= 4:
        raise ValueError(
            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
        )

    # sanity check on axes length
    if len(axes) != n_dims:
        raise ValueError(f"Incorrect axes length (got {axes}).")

    # NOTE: axes are not reordered here (a `fix_axes` step was left disabled).
    return array
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Functions relating to writing image files of different formats."""
|
|
2
|
+
|
|
3
|
+
# Public API of the write subpackage, re-exported from the `get_func` and
# `tiff` modules by the imports below.
__all__ = [
    "get_write_func",
    "write_tiff",
    "WriteFunc",
    "SupportedWriteType",
]
|
|
9
|
+
|
|
10
|
+
from .get_func import (
|
|
11
|
+
SupportedWriteType,
|
|
12
|
+
WriteFunc,
|
|
13
|
+
get_write_func,
|
|
14
|
+
)
|
|
15
|
+
from .tiff import write_tiff
|