careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Patch validation functions."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes must not exceed the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, whose first two axes are S and C, followed by the spatial
        axes ((Z), Y, X).
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. every
        dimension except the first two (S and C).
    is_3d_patch : bool
        Whether the patch is 3D or not.

    Raises
    ------
    ValueError
        If the number of patch dimensions differs from the number of spatial
        dimensions of the array (array dimensions minus two).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            "There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
        )

    # Sanity checks on patch sizes versus array dimension
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        # Report the Z dimension that was actually compared against (arr.shape[-3]),
        # not arr.shape[1], which is the channel axis.
        raise ValueError(
            "Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]}). Check the axes "
            "order."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            "At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
            "Check the axes order."
        )
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Collate function for tiling."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data.dataloader import default_collate
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
    """
    Collate tiles received from the CAREamics prediction dataloader.

    Every element of `batch` is a pair made of an array and its TileInformation.
    The arrays are passed together to PyTorch's `default_collate`. The
    TileInformation objects are not collated; they are gathered in a plain list
    that keeps the order of the batch.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of (tile, tile information) pairs.

    Returns
    -------
    Any
        A 2-tuple: the result of `default_collate` on the list of arrays, and the
        list of TileInformation objects.
    """
    arrays = [array for array, _info in batch]
    infos = [info for _array, info in batch]
    collated = default_collate(arrays)
    return collated, infos
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""Functions to reimplement the tiling in the Disentangle repository."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
import itertools
|
|
5
|
+
from typing import Any, Generator, Optional, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
|
|
10
|
+
from careamics.config.tile_information import TileInformation
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def extract_tiles(
    arr: NDArray,
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    padding_kwargs: Optional[dict[str, Any]] = None,
) -> Generator[tuple[NDArray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array, which is additionally padded so that the
    section of each tile that contributes to the final image comes from the center
    of the tile.

    The method returns a generator that yields tuples of array and tile
    information. The latter includes whether the tile is the last one, the
    coordinates of the overlap crop, and the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape
    C(Z)YX, where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : 1D array-like of int
        Tile sizes in each spatial dimension, of length 2 or 3. Lists and tuples
        are accepted and converted to numpy arrays.
    overlaps : 1D array-like of int
        Overlap values in each spatial dimension, of length 2 or 3. Lists and
        tuples are accepted and converted to numpy arrays.
    padding_kwargs : dict, optional
        The arguments of `np.pad` after the first two arguments, `array` and
        `pad_width`. If not specified the default will be `{"mode": "reflect"}`.
        See `numpy.pad` docs:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    if padding_kwargs is None:
        padding_kwargs = {"mode": "reflect"}

    # The element-wise arithmetic below (`tile_size - overlaps`, products with the
    # grid coordinates) requires numpy arrays, so plain sequences are converted.
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    # Distance between the starts of two consecutive tiles (same for all samples).
    stitch_size = tile_size - overlaps

    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx, ...]
        data_shape = np.array(sample.shape)

        # Pad the spatial axes only (not C) to ensure evenly spaced and
        # overlapping tiles.
        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
        padding = ((0, 0), *spatial_padding)
        sample = np.pad(sample, padding, **padding_kwargs)

        # The number of tiles in each dimension, should be of length 2 or 3.
        # It is computed from the unpadded shape.
        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)

        # itertools.product is the equivalent of nested loops over the tile grid
        for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]):
            grid_coords = np.array(tile_grid_coords)

            # Calculate crop coordinates in the padded sample.
            crop_coords_start = grid_coords * stitch_size
            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                ...,
                *[
                    slice(coords, coords + extent)
                    for coords, extent in zip(crop_coords_start, tile_size)
                ],
            )
            tile = sample[crop_slices]

            tile_info = compute_tile_info(
                grid_coords,
                # A fresh copy per tile, so that tile infos never share an array.
                np.array(data_shape),
                tile_size,
                overlaps,
                sample_idx,
            )
            # TODO: kinda weird this is a generator,
            # -> doesn't really save memory ? Don't think there are any places the
            # tiles are not extracted all at the same time.
            # Although I guess it would make sense for a zarr tile extractor.
            yield tile, tile_info
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def compute_tile_info(
    tile_grid_coords: NDArray[np.int_],
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    sample_id: int = 0,
) -> TileInformation:
    """
    Build the TileInformation of the tile located at `tile_grid_coords`.

    Parameters
    ----------
    tile_grid_coords : 1D np.array of int
        Position of the tile in the tile grid, ((Z), Y, X). For 2D tiling, the
        second tile of the first row of tiles has coordinates (0, 1).
    data_shape : 1D np.array of int
        Shape of the data, (C, (Z), Y, X) where Z is optional. Only the trailing
        `len(tile_size)` entries are used as spatial extents.
    tile_size : 1D np.array of int
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : 1D np.array of int
        Overlap values in each dimension, of length 2 or 3.
    sample_id : int, default=0
        Identifier of the sample the tile belongs to.

    Returns
    -------
    TileInformation
        How to crop the tile and where to stitch it to rebuild the full image.
    """
    spatial_extent = data_shape[-len(tile_size) :]

    # Region of the full image that this tile contributes to.
    step = tile_size - overlaps
    stitch_start = tile_grid_coords * step
    stitch_end = stitch_start + step

    # Start of the tile in image coordinates; the tile begins `overlaps // 2`
    # before its stitched region (the amount `compute_padding` adds in front).
    tile_start = stitch_start - overlaps // 2

    # Clamp the stitched region to the image bounds.
    below_zero = stitch_start < 0
    beyond_end = stitch_end > spatial_extent
    stitch_start[below_zero] = 0
    stitch_end[beyond_end] = spatial_extent[beyond_end]

    # Same region, expressed relative to the start of the tile.
    crop_start = stitch_start - tile_start
    crop_end = crop_start + (stitch_end - stitch_start)

    stitch_coords = tuple(zip(stitch_start, stitch_end))
    overlap_crop_coords = tuple(zip(crop_start, crop_end))

    # The last tile is the one sitting at the far corner of the grid.
    grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    is_last = (tile_grid_coords == (grid_shape - 1)).all()

    return TileInformation(
        array_shape=data_shape,
        last_tile=is_last,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=sample_id,
    )
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def compute_padding(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[tuple[int, int], ...]:
    """
    Compute the padding that keeps stitched data in the center of each tile.

    The padding is meant for an array of shape `data_shape`, so that after tiling
    and stitching, every stitched region comes from the center of a tile, including
    the tiles at the array boundaries.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        Shape of the data to be tiled and stitched together. Only the trailing
        `len(tile_size)` entries (the spatial dimensions) are used.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of (int, int)
        One `(before, after)` pair per spatial dimension, usable as the
        `pad_width` argument of `numpy.pad` for the spatial axes.
    """
    grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    spatial_shape = data_shape[-len(tile_size) :]

    # Total extent spanned by the whole tile grid along each spatial axis.
    covered_shape = (tile_size - overlaps) * grid_shape + overlaps

    before = overlaps // 2
    after = covered_shape - spatial_shape - before
    return tuple(zip(before, after))
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
    """Calculate the number of tiles in a specific dimension.

    Consecutive tiles are shifted by `tile_size - overlap` pixels, so the number of
    tiles is `ceil(axis_size / (tile_size - overlap))`.

    Parameters
    ----------
    axis_size : int
        The length of the data in a specific dimension.
    tile_size : int
        The length of the tiles in a specific dimension.
    overlap : int
        The tile overlap in a specific dimension.

    Returns
    -------
    int
        The number of tiles that fit in one dimension given the arguments.

    Raises
    ------
    ValueError
        If `overlap` is not strictly smaller than `tile_size`.
    """
    step = tile_size - overlap
    if step <= 0:
        # A zero or negative step cannot tile the axis (it previously led to a
        # division by zero or a negative tile count).
        raise ValueError(
            f"Overlap ({overlap}) must be smaller than the tile size ({tile_size})."
        )
    # Integer ceiling division: exact for arbitrarily large integers, unlike
    # rounding a floating point quotient.
    return int(-(-axis_size // step))
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def total_n_tiles(
    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
) -> int:
    """Calculate the total number of tiles over all spatial dimensions.

    Parameters
    ----------
    data_shape : sequence of int
        Shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only the trailing `len(tile_size)` entries are used.
    tile_size : sequence of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : sequence of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    int
        Product of the number of tiles along each spatial dimension.
    """
    n_total = 1
    # Spatial dimensions are assumed to be the trailing ones: walk them from the
    # last one backwards.
    for axis in reversed(range(-len(tile_size), 0)):
        n_total *= n_tiles_1d(data_shape[axis], tile_size[axis], overlaps[axis])
    return n_total
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def compute_tile_grid_shape(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[int, ...]:
    """Calculate the number of tiles along each spatial dimension.

    The result can be thought of as the shape of a grid of tiles.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        Shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only the trailing `len(tile_size)` entries are used.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of int
        The number of tiles in each direction, ((Z), Y, X).
    """
    n_spatial = len(tile_size)
    # Spatial dimensions are assumed to be the trailing ones of every argument.
    return tuple(
        n_tiles_1d(data_shape[axis], tile_size[axis], overlaps[axis])
        for axis in range(-n_spatial, 0)
    )
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Tiled patching utilities."""
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from typing import Generator, List, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _compute_crop_and_stitch_coords_1d(
|
|
12
|
+
axis_size: int, tile_size: int, overlap: int
|
|
13
|
+
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
|
|
14
|
+
"""
|
|
15
|
+
Compute the coordinates of each tile along an axis, given the overlap.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
axis_size : int
|
|
20
|
+
Length of the axis.
|
|
21
|
+
tile_size : int
|
|
22
|
+
Size of the tile for the given axis.
|
|
23
|
+
overlap : int
|
|
24
|
+
Size of the overlap for the given axis.
|
|
25
|
+
|
|
26
|
+
Returns
|
|
27
|
+
-------
|
|
28
|
+
Tuple[Tuple[int, ...], ...]
|
|
29
|
+
Tuple of all coordinates for given axis.
|
|
30
|
+
"""
|
|
31
|
+
# Compute the step between tiles
|
|
32
|
+
step = tile_size - overlap
|
|
33
|
+
crop_coords = []
|
|
34
|
+
stitch_coords = []
|
|
35
|
+
overlap_crop_coords = []
|
|
36
|
+
|
|
37
|
+
# Iterate over the axis with step
|
|
38
|
+
for i in range(0, max(1, axis_size - overlap), step):
|
|
39
|
+
# Check if the tile fits within the axis
|
|
40
|
+
if i + tile_size <= axis_size:
|
|
41
|
+
# Add the coordinates to crop one tile
|
|
42
|
+
crop_coords.append((i, i + tile_size))
|
|
43
|
+
|
|
44
|
+
# Add the pixel coordinates of the cropped tile in the original image space
|
|
45
|
+
stitch_coords.append(
|
|
46
|
+
(
|
|
47
|
+
i + overlap // 2 if i > 0 else 0,
|
|
48
|
+
(
|
|
49
|
+
i + tile_size - overlap // 2
|
|
50
|
+
if crop_coords[-1][1] < axis_size
|
|
51
|
+
else axis_size
|
|
52
|
+
),
|
|
53
|
+
)
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# Add the coordinates to crop the overlap from the prediction.
|
|
57
|
+
overlap_crop_coords.append(
|
|
58
|
+
(
|
|
59
|
+
overlap // 2 if i > 0 else 0,
|
|
60
|
+
(
|
|
61
|
+
tile_size - overlap // 2
|
|
62
|
+
if crop_coords[-1][1] < axis_size
|
|
63
|
+
else tile_size
|
|
64
|
+
),
|
|
65
|
+
)
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# If the tile does not fit within the axis, perform the abovementioned
|
|
69
|
+
# operations starting from the end of the axis
|
|
70
|
+
else:
|
|
71
|
+
# if (axis_size - tile_size, axis_size) not in crop_coords:
|
|
72
|
+
crop_coords.append((max(0, axis_size - tile_size), axis_size))
|
|
73
|
+
last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1
|
|
74
|
+
stitch_coords.append((last_tile_end_coord, axis_size))
|
|
75
|
+
overlap_crop_coords.append(
|
|
76
|
+
(tile_size - (axis_size - last_tile_end_coord), tile_size)
|
|
77
|
+
)
|
|
78
|
+
break
|
|
79
|
+
return crop_coords, stitch_coords, overlap_crop_coords
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array. The generator yields pairs of a tile and its
    TileInformation, which records whether the tile is the last one of its sample,
    the coordinates of the overlap crop, and the coordinates where the cropped
    tile is stitched.

    Input array should have shape SC(Z)YX, while the returned tiles have shape
    C(Z)YX, where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each spatial dimension, of length 2 or 3.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    n_samples = arr.shape[0]
    for sample_idx in range(n_samples):
        sample: np.ndarray = arr[sample_idx, ...]

        # One (crop, stitch, overlap crop) triple of coordinate lists per spatial
        # axis. Axis 0 of the sample is C, hence the `axis + 1` offset.
        # Example: for an axis of size 35, a tile size of 32 and an overlap of 24,
        # _compute_crop_and_stitch_coords_1d returns
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        per_axis_coords = []
        for axis in range(len(tile_size)):
            per_axis_coords.append(
                _compute_crop_and_stitch_coords_1d(
                    sample.shape[axis + 1], tile_size[axis], overlaps[axis]
                )
            )

        # Regroup by coordinate type: one tuple of per-axis lists for each type.
        crop_per_axis, stitch_per_axis, overlap_crop_per_axis = zip(*per_axis_coords)

        # Total number of tiles in this sample (product of per-axis tile counts).
        n_tiles = int(np.prod([len(coords) for coords in crop_per_axis]))

        tile_combinations = zip(
            itertools.product(*crop_per_axis),
            itertools.product(*stitch_per_axis),
            itertools.product(*overlap_crop_per_axis),
        )
        for tile_idx, (crop, stitch, overlap_crop) in enumerate(tile_combinations):
            slices = tuple(slice(start, end) for start, end in crop)
            tile: np.ndarray = sample[(..., *slices)]  # type: ignore

            yield tile, TileInformation(
                array_shape=sample.shape,
                last_tile=tile_idx == n_tiles - 1,
                overlap_crop_coords=overlap_crop,
                stitch_coords=stitch,
                sample_id=sample_idx,
            )
|