careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,190 @@
1
+ from typing import Generator, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import zarr
5
+
6
+ from .validate_patch_dimension import validate_patch_dimensions
7
+
8
+
9
# TODO: split into testable functions
def extract_patches_random(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
    seed: Optional[int] = None,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
    """
    Generate patches from an array in a random manner.

    The method calculates how many patches the image can be divided into and then
    extracts an equal number of random patches.

    It returns a generator that yields the following:

    - patch: np.ndarray, dimension C(Z)YX.
    - target_patch: np.ndarray, dimension C(Z)YX, if the target is present, None
        otherwise.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension ((Z)YX).
    target : Optional[np.ndarray], optional
        Target array, cropped at the same coordinates as `arr` (presumably of the
        same shape as `arr`), by default None.
    seed : Optional[int], optional
        Seed of the random number generator, by default None (non-deterministic).

    Yields
    ------
    Tuple[np.ndarray, Optional[np.ndarray]]
        Float32 patch and the matching float32 target patch (or None).
    """
    is_3d_patch = len(patch_size) == 3

    # patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # Patch size over the C(Z)YX dimensions of a single sample: the patch always
    # spans all channels, followed by the spatial patch size.
    sample_patch_size = [arr.shape[1], *patch_size]

    # Number of patches per sample: ceil(sample volume / patch volume). Every
    # sample has the same shape, so this is computed once.
    n_patches = int(np.ceil(np.prod(arr.shape[1:]) / np.prod(sample_patch_size)))

    rng = np.random.default_rng(seed)

    # iterate over the number of samples (S or T)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]
        target_sample = target[sample_idx, ...] if target is not None else None

        for _ in range(n_patches):
            # Random start in each dimension such that the patch fits in the
            # sample (always 0 along C, since the patch spans all channels).
            crop_coords = [
                rng.integers(0, dim - size, endpoint=True)
                for dim, size in zip(sample.shape, sample_patch_size)
            ]
            crop = (
                ...,  # type: ignore
                *[
                    slice(start, start + size)
                    for start, size in zip(crop_coords, sample_patch_size)
                ],
            )

            # astype returns a new array, so patches never alias the input arrays
            patch = sample[crop].astype(np.float32)

            if target_sample is not None:
                yield patch, target_sample[crop].astype(np.float32)
            else:
                yield patch, None
105
+
106
+
107
def extract_patches_random_from_chunks(
    arr: zarr.Array,
    patch_size: Union[List[int], Tuple[int, ...]],
    chunk_size: Union[List[int], Tuple[int, ...]],
    chunk_limit: Optional[int] = None,
    seed: Optional[int] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate random patches from randomly located chunks of a zarr array.

    A chunk of size `chunk_size` is read at a random position of the array, as many
    times as the array has chunks (or `chunk_limit` times). After squeezing singleton
    dimensions, each sample of the chunk yields ceil(chunk volume / patch volume)
    patches taken at random positions.

    Parameters
    ----------
    arr : zarr.Array
        Input array, with two more dimensions than `patch_size`.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension.
    chunk_size : Union[List[int], Tuple[int, ...]]
        Size of the chunks read from the array, applied to its last
        `len(chunk_size)` dimensions.
    chunk_limit : Optional[int], optional
        Number of chunks to read; None (or 0) means the total number of chunks of
        the array. By default None.
    seed : Optional[int], optional
        Seed of the random number generator, by default None (non-deterministic).

    Yields
    ------
    np.ndarray
        Float32 patch of shape `patch_size`.

    Raises
    ------
    ValueError
        If `chunk_size` has more dimensions than `arr`.
    AssertionError
        If a squeezed chunk sample does not have as many dimensions as
        `patch_size`.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    n_dims = len(arr.shape)
    if len(chunk_size) > n_dims:
        raise ValueError(
            f"Chunk size {chunk_size} has more dimensions than the array shape "
            f"{arr.shape}."
        )

    # The chunk slices are placed after an Ellipsis, i.e. on the trailing
    # dimensions, so the random chunk positions must be bounded by those same
    # dimensions.
    sliced_shape = arr.shape[n_dims - len(chunk_size) :]

    rng = np.random.default_rng(seed)
    # `nchunks` is zarr's public equivalent of np.prod(arr._cdata_shape)
    num_chunks = chunk_limit if chunk_limit else arr.nchunks

    # Iterate over num chunks in the array
    for _ in range(num_chunks):
        chunk_slices = []
        for dim, size in zip(sliced_shape, chunk_size):
            start = rng.integers(0, max(0, dim - size), endpoint=True)
            chunk_slices.append(slice(start, start + size))
        chunk = arr[(..., *chunk_slices)].squeeze()

        # Add a singleton dimension if the chunk does not have a sample dimension
        if chunk.ndim == len(patch_size):
            chunk = np.expand_dims(chunk, axis=0)

        # Iterate over num samples (S)
        for sample_idx in range(chunk.shape[0]):
            spatial_chunk = chunk[sample_idx]
            assert spatial_chunk.ndim == len(
                patch_size
            ), "Requested chunk shape is not equal to patch size"

            n_patches = int(
                np.ceil(np.prod(spatial_chunk.shape) / np.prod(patch_size))
            )

            # Iterate over the number of patches
            for _ in range(n_patches):
                patch_slices = []
                for dim, size in zip(spatial_chunk.shape, patch_size):
                    start = rng.integers(0, dim - size, endpoint=True)
                    patch_slices.append(slice(start, start + size))
                # astype copies, so the yielded patch never aliases the chunk
                yield spatial_chunk[(..., *patch_slices)].astype(np.float32)
@@ -0,0 +1,206 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ from skimage.util import view_as_windows
5
+
6
+ from .validate_patch_dimension import validate_patch_dimensions
7
+
8
+
9
+ def _compute_number_of_patches(
10
+ arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
11
+ ) -> Tuple[int, ...]:
12
+ """
13
+ Compute the number of patches that fit in each dimension.
14
+
15
+ Parameters
16
+ ----------
17
+ arr : Tuple[int, ...]
18
+ Shape of the input array.
19
+ patch_sizes : Tuple[int]
20
+ Shape of the patches.
21
+
22
+ Returns
23
+ -------
24
+ Tuple[int]
25
+ Number of patches in each dimension.
26
+ """
27
+ if len(arr_shape) != len(patch_sizes):
28
+ raise ValueError(
29
+ f"Array shape {arr_shape} and patch size {patch_sizes} should have the "
30
+ f"same dimension, including singleton dimension for S and equal dimension "
31
+ f"for C."
32
+ )
33
+
34
+ try:
35
+ n_patches = [
36
+ np.ceil(arr_shape[i] / patch_sizes[i]).astype(int)
37
+ for i in range(len(patch_sizes))
38
+ ]
39
+ except IndexError as e:
40
+ raise ValueError(
41
+ f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}"
42
+ ) from e
43
+
44
+ return tuple(n_patches)
45
+
46
+
47
def _compute_overlap(
    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the overlap between patches in each dimension.

    Along each dimension, the number of patches is the array size divided by the
    patch size, rounded up. If the patches tile the dimension exactly, the overlap
    is 0. Otherwise, the excess `n_patches * patch_size - array_size` is spread
    over the `n_patches - 1` gaps between consecutive patches, rounded up.

    Parameters
    ----------
    arr_shape : Tuple[int, ...]
        Input array shape.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches, with the same number of dimensions as `arr_shape`.

    Returns
    -------
    Tuple[int, ...]
        Overlap between patches in each dimension.

    Raises
    ------
    ValueError
        If `arr_shape` and `patch_sizes` do not have the same length.
    """
    n_patches = _compute_number_of_patches(arr_shape, patch_sizes)

    overlaps = []
    for n, size, dim in zip(n_patches, patch_sizes, arr_shape):
        excess = max(0, n * size - dim)
        # a single patch has no gap, avoid dividing by zero
        gaps = max(1, n - 1)
        # exact integer ceiling division
        overlaps.append(int(-(-excess // gaps)))
    return tuple(overlaps)
78
+
79
+
80
+ def _compute_patch_steps(
81
+ patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
82
+ ) -> Tuple[int, ...]:
83
+ """
84
+ Compute steps between patches.
85
+
86
+ Parameters
87
+ ----------
88
+ patch_sizes : Tuple[int]
89
+ Size of the patches.
90
+ overlaps : Tuple[int]
91
+ Overlap between patches.
92
+
93
+ Returns
94
+ -------
95
+ Tuple[int]
96
+ Steps between patches.
97
+ """
98
+ steps = [
99
+ min(patch_sizes[i] - overlaps[i], patch_sizes[i])
100
+ for i in range(len(patch_sizes))
101
+ ]
102
+ return tuple(steps)
103
+
104
+
105
# NOTE: the target is stacked with the input in this function (rather than by the
# caller) so that input and target patches are windowed and shuffled together.
def _compute_patch_views(
    arr: np.ndarray,
    window_shape: List[int],
    step: Tuple[int, ...],
    output_shape: List[int],
    target: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Compute shuffled patches of an array, and optionally of its target.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the patches are extracted.
    window_shape : List[int]
        Shape of the patches, one value per dimension of `arr`.
    step : Tuple[int, ...]
        Steps between patches, one value per dimension of `arr`.
    output_shape : List[int]
        Shape into which the patches are reshaped when there is no target,
        typically (-1, C, (Z), Y, X).
    target : Optional[np.ndarray], optional
        Target array with the same shape as `arr`, by default None.

    Returns
    -------
    np.ndarray
        If `target` is None, the patches reshaped to `output_shape` and shuffled
        along the first axis. Otherwise, an array of shape
        (2, n_patches, *output_shape[1:]) whose index 0 holds the input patches
        and index 1 the target patches, shuffled with the same permutation.
    """
    rng = np.random.default_rng()
    input_arr = arr

    if target is not None:
        # Stack input and target on a new leading axis and extend the windows
        # over it, so that each window holds an (input, target) pair.
        arr = np.stack([arr, target], axis=0)
        window_shape = [arr.shape[0], *window_shape]
        step = (arr.shape[0], *step)
        output_shape = [arr.shape[0], -1, arr.shape[2], *output_shape[2:]]

    # view_as_windows returns (*window grid, *window), with one grid axis and one
    # window axis per dimension of `arr`.
    windows = view_as_windows(arr, window_shape=window_shape, step=step)

    if target is not None:
        # Move the input/target axis (the first window axis) in front of the grid
        # axes before flattening. Reshaping (*grid, 2, ...) directly into
        # (2, -1, ...) would split the flattened windows into two halves and
        # interleave input and target patches.
        windows = np.moveaxis(windows, arr.ndim, 0)

    patches = windows.reshape(*output_shape)

    # NumPy can return a view instead of a copy when the strides allow it (e.g.
    # for some exact tilings); copy in that case so that the in-place shuffle
    # below cannot reorder the caller's data.
    if np.may_share_memory(patches, input_arr):
        patches = patches.copy()

    if target is not None:
        # same permutation for input and target patches
        rng.shuffle(patches, axis=1)
    else:
        rng.shuffle(patches, axis=0)
    return patches
148
+
149
+
150
def extract_patches_sequential(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Generate patches from an array in a sequential manner.

    Array dimensions should be SC(Z)YX, where S and C can be singleton dimensions.
    The patches are generated sequentially and cover the whole array; along
    dimensions that are not divisible by the patch size, neighbouring patches
    overlap. The patches are returned in a shuffled order.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension ((Z)YX).
    target : Optional[np.ndarray], optional
        Target array, stacked with `arr` and therefore expected to have the same
        shape, by default None.

    Returns
    -------
    Tuple[np.ndarray, Optional[np.ndarray]]
        Patches of shape (n_patches, C, (Z), Y, X), and the corresponding target
        patches, or None if no target is provided.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # Extend the spatial patch size with the S (one sample per patch) and C (all
    # channels) dimensions
    full_patch_size = [1, arr.shape[1], *patch_size]

    # Overlap needed for the patches to cover every dimension, and resulting steps
    overlaps = _compute_overlap(arr_shape=arr.shape, patch_sizes=full_patch_size)
    window_steps = _compute_patch_steps(
        patch_sizes=full_patch_size, overlaps=overlaps
    )

    # Patches are flattened to (n_patches, C, Z, Y, X) or (n_patches, C, Y, X)
    output_shape = [-1, *full_patch_size[1:]]

    patches = _compute_patch_views(
        arr,
        window_shape=full_patch_size,
        step=window_steps,
        output_shape=output_shape,
        target=target,
    )

    if target is not None:
        # With a target, _compute_patch_views stacks the input and target patches
        # along a new leading axis.
        return patches[0, ...], patches[1, ...]

    return patches, None
@@ -0,0 +1,158 @@
1
+ import itertools
2
+ from typing import Generator, List, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from careamics.config.tile_information import TileInformation
7
+
8
+
9
+ def _compute_crop_and_stitch_coords_1d(
10
+ axis_size: int, tile_size: int, overlap: int
11
+ ) -> Tuple[List[Tuple[int, ...]], ...]:
12
+ """
13
+ Compute the coordinates of each tile along an axis, given the overlap.
14
+
15
+ Parameters
16
+ ----------
17
+ axis_size : int
18
+ Length of the axis.
19
+ tile_size : int
20
+ Size of the tile for the given axis.
21
+ overlap : int
22
+ Size of the overlap for the given axis.
23
+
24
+ Returns
25
+ -------
26
+ Tuple[Tuple[int, ...], ...]
27
+ Tuple of all coordinates for given axis.
28
+ """
29
+ # Compute the step between tiles
30
+ step = tile_size - overlap
31
+ crop_coords = []
32
+ stitch_coords = []
33
+ overlap_crop_coords = []
34
+
35
+ # Iterate over the axis with step
36
+ for i in range(0, max(1, axis_size - overlap), step):
37
+ # Check if the tile fits within the axis
38
+ if i + tile_size <= axis_size:
39
+ # Add the coordinates to crop one tile
40
+ crop_coords.append((i, i + tile_size))
41
+
42
+ # Add the pixel coordinates of the cropped tile in the original image space
43
+ stitch_coords.append(
44
+ (
45
+ i + overlap // 2 if i > 0 else 0,
46
+ i + tile_size - overlap // 2
47
+ if crop_coords[-1][1] < axis_size
48
+ else axis_size,
49
+ )
50
+ )
51
+
52
+ # Add the coordinates to crop the overlap from the prediction.
53
+ overlap_crop_coords.append(
54
+ (
55
+ overlap // 2 if i > 0 else 0,
56
+ tile_size - overlap // 2
57
+ if crop_coords[-1][1] < axis_size
58
+ else tile_size,
59
+ )
60
+ )
61
+
62
+ # If the tile does not fit within the axis, perform the abovementioned
63
+ # operations starting from the end of the axis
64
+ else:
65
+ # if (axis_size - tile_size, axis_size) not in crop_coords:
66
+ crop_coords.append((max(0, axis_size - tile_size), axis_size))
67
+ last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1
68
+ stitch_coords.append((last_tile_end_coord, axis_size))
69
+ overlap_crop_coords.append(
70
+ (tile_size - (axis_size - last_tile_end_coord), tile_size)
71
+ )
72
+ break
73
+ return crop_coords, stitch_coords, overlap_crop_coords
74
+
75
+
76
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """
    Generate tiles from the input array with specified overlap.

    The tiles cover the whole array. The method returns a generator that yields
    tuples of array and tile information, the latter includes whether the tile is
    the last one of its sample, the coordinates of the overlap crop, and the
    coordinates of the stitched tile.

    The array has shape SC(Z)YX, where S and C can be singletons. Each yielded tile
    has shape C(Z)YX.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each spatial dimension, in the same order as `tile_size`.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]
        # Shape stored in the information of every tile of this sample (singleton
        # dimensions removed); loop invariant, so computed once per sample.
        sample_shape = sample.squeeze().shape

        # One (crop, stitch, overlap crop) triplet of coordinate lists per spatial
        # axis. For instance, for an axis of size 35, a tile size of 32 and an
        # overlap of 24, _compute_crop_and_stitch_coords_1d returns
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        coords_per_axis = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[i + 1], tile_size[i], overlaps[i]
            )
            for i in range(len(tile_size))
        ]

        # Regroup by coordinate type: one tuple per type, holding one list per axis
        all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
            *coords_per_axis
        )

        # Tiles are enumerated as the cartesian product of the per-axis
        # coordinates; this is the index of the last one.
        max_tile_idx = np.prod([len(axis) for axis in all_crop_coords]) - 1

        tiles_coords = zip(
            itertools.product(*all_crop_coords),
            itertools.product(*all_stitch_coords),
            itertools.product(*all_overlap_crop_coords),
        )
        for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
            tiles_coords
        ):
            # Slices apply to the trailing (spatial) dimensions, C is kept whole
            tile: np.ndarray = sample[
                (..., *[slice(start, end) for start, end in crop_coords])  # type: ignore
            ]

            tile_info = TileInformation(
                array_shape=sample_shape,
                tiled=True,
                last_tile=bool(tile_idx == max_tile_idx),
                overlap_crop_coords=overlap_crop_coords,
                stitch_coords=stitch_coords,
            )

            yield tile, tile_info
@@ -0,0 +1,60 @@
1
+ from typing import List, Tuple, Union
2
+
3
+ import numpy as np
4
+
5
+
6
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes are not larger than the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. all
        dimensions except the first two (S and C).
    is_3d_patch : bool
        Whether the patch is 3D or not (enables the Z check).

    Raises
    ------
    ValueError
        If the patch size is not consistent with the array shape (the array must
        have two more dimensions than the patch).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape})."
        )

    # Sanity checks on patch sizes versus array dimension. Z is the third-to-last
    # dimension of the array, so the error reports arr.shape[-3] (arr.shape[1] is C).
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]})."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
        )
+ )