careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,64 @@
1
+ """Patch validation functions."""
2
+
3
+ from typing import List, Tuple, Union
4
+
5
+ import numpy as np
6
+
7
+
8
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes are not larger than the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, with the sample (S) and channel (C) axes first, followed by
        the spatial axes ((Z), Y, X).
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. every
        dimension except the first two (S and C).
    is_3d_patch : bool
        Whether the patch is 3D or not.

    Raises
    ------
    ValueError
        If the number of patch sizes does not equal the number of spatial array
        dimensions (array dimensions minus two).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    # One patch size per spatial axis: the first two array axes are S and C.
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
        )

    # Sanity checks on patch sizes versus array dimension. Z is the third-to-last
    # axis, so the error message must report that same axis.
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]}). Check the axes "
            f"order."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
            f"Check the axes order."
        )
@@ -0,0 +1,10 @@
1
+ """Tiling functions."""
2
+
3
# Public API of the tiling package. Only names actually imported into this
# package may be listed: a missing name makes `from ... import *` raise
# AttributeError (stitch_prediction is not imported here).
__all__ = [
    "extract_tiles",
    "collate_tiles",
]
8
+
9
+ from .collate_tiles import collate_tiles
10
+ from .tiled_patching import extract_tiles
@@ -0,0 +1,33 @@
1
+ """Collate function for tiling."""
2
+
3
+ from typing import Any, List, Tuple
4
+
5
+ import numpy as np
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from careamics.config.tile_information import TileInformation
9
+
10
+
11
def collate_tiles(
    batch: List[Tuple[np.ndarray, TileInformation]],
) -> Tuple[Any, List[TileInformation]]:
    """
    Collate tiles received from CAREamics prediction dataloader.

    Each batch element is a tuple of a tile array and its `TileInformation`. The
    arrays are collated together with PyTorch's `default_collate`. The
    `TileInformation` objects are not passed to `default_collate`; they are
    returned as a plain list, in the same order as the arrays.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of (tile, tile information) pairs.

    Returns
    -------
    Tuple[Any, List[TileInformation]]
        The output of `default_collate` applied to the tile arrays, and the list of
        the corresponding tile information objects.
    """
    tiles = [tile for tile, _ in batch]
    tile_infos = [tile_info for _, tile_info in batch]

    return default_collate(tiles), tile_infos
@@ -0,0 +1,282 @@
1
+ """Functions to reimplement the tiling in the Disentangle repository."""
2
+
3
+ import builtins
4
+ import itertools
5
+ from typing import Any, Generator, Optional, Union
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ from careamics.config.tile_information import TileInformation
11
+
12
+
13
def extract_tiles(
    arr: NDArray,
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    padding_kwargs: Optional[dict[str, Any]] = None,
) -> Generator[tuple[NDArray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array; which will be additionally padded, to ensure that
    the section of the tile that contributes to the final image comes from the center
    of the tile.

    The method returns a generator that yields tuples of array and tile information,
    the latter includes whether the tile is the last one, the coordinates of the
    overlap crop, and the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
    where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : 1D numpy.ndarray or tuple
        Tile sizes in each dimension, of length 2 or 3. Lists and tuples are
        converted to numpy arrays.
    overlaps : 1D numpy.ndarray or tuple
        Overlap values in each dimension, of length 2 or 3. Lists and tuples are
        converted to numpy arrays.
    padding_kwargs : dict, optional
        The arguments of `np.pad` after the first two arguments, `array` and
        `pad_width`. If not specified the default will be `{"mode": "reflect"}`. See
        `numpy.pad` docs:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    if padding_kwargs is None:
        padding_kwargs = {"mode": "reflect"}

    # The coordinate arithmetic below is element-wise, which requires arrays;
    # plain lists/tuples would fail on subtraction.
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    # Distance between the starts of consecutive tiles (same for every sample).
    stitch_size = tile_size - overlaps

    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx, ...]
        data_shape = np.array(sample.shape)

        # Pad only the spatial axes (the channel axis gets (0, 0)) so tiles are
        # evenly spaced and the stitched part of each tile comes from its center.
        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
        padding = ((0, 0), *spatial_padding)
        sample = np.pad(sample, padding, **padding_kwargs)

        # The number of tiles in each dimension, should be of length 2 or 3
        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)

        # itertools.product is the equivalent of nested loops over the grid
        for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]):
            # Crop coordinates are expressed in the padded sample.
            crop_coords_start = np.array(tile_grid_coords) * stitch_size
            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                ...,
                *[
                    slice(coords, coords + extent)
                    for coords, extent in zip(crop_coords_start, tile_size)
                ],
            )
            tile = sample[crop_slices]

            tile_info = compute_tile_info(
                np.array(tile_grid_coords),
                np.array(data_shape),
                np.array(tile_size),
                np.array(overlaps),
                sample_idx,
            )
            # NOTE(review): the whole sample is already padded in memory, so the
            # generator does not reduce peak memory here; it may matter for lazily
            # loaded (e.g. zarr) inputs.
            yield tile, tile_info
94
+
95
+
96
def compute_tile_info(
    tile_grid_coords: NDArray[np.int_],
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    sample_id: int = 0,
) -> TileInformation:
    """
    Compute the tile information for a tile with the coordinates `tile_grid_coords`.

    All coordinates, the array shape and the last-tile flag are converted to native
    Python `int`/`bool`/`tuple` values. These are the same kinds of values that
    `tiled_patching.extract_tiles` passes to `TileInformation`.

    Parameters
    ----------
    tile_grid_coords : 1D np.array of int
        The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
        tiling the coordinates for the second tile in the first row of tiles would be
        (0, 1).
    data_shape : 1D np.array of int
        The shape of the data, should be (C, (Z), Y, X) where Z is optional.
    tile_size : 1D np.array of int
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : 1D np.array of int
        Overlap values in each dimension, of length 2 or 3.
    sample_id : int, default=0
        An ID to identify which sample a tile belongs to.

    Returns
    -------
    TileInformation
        Information that describes how to crop and stitch a tile to create a full image.
    """
    # Accept any integer sequence; the arithmetic below is element-wise.
    tile_grid_coords = np.asarray(tile_grid_coords)
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    spatial_dims_shape = data_shape[-len(tile_size) :]

    # The extent of the tile which will make up part of the stitched image.
    stitch_size = tile_size - overlaps
    stitch_coords_start = tile_grid_coords * stitch_size
    stitch_coords_end = stitch_coords_start + stitch_size

    # Start of the tile in unpadded image coordinates: tiles are cut from data
    # padded by `overlaps // 2` at the start, so they begin half an overlap early.
    tile_coords_start = stitch_coords_start - overlaps // 2

    # --- clamp the stitch region to the image bounds
    stitch_coords_start = np.maximum(stitch_coords_start, 0)
    stitch_coords_end = np.minimum(stitch_coords_end, spatial_dims_shape)

    # --- calculate overlap crop coords (relative to the tile)
    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
    overlap_crop_coords_end = overlap_crop_coords_start + (
        stitch_coords_end - stitch_coords_start
    )

    # --- combine start and end, as native Python ints
    stitch_coords = tuple(
        (int(start), int(end))
        for start, end in zip(stitch_coords_start, stitch_coords_end)
    )
    overlap_crop_coords = tuple(
        (int(start), int(end))
        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
    )

    # --- Check if last tile (the tile at the last index along every grid axis)
    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    last_tile = bool(np.all(tile_grid_coords == (tile_grid_shape - 1)))

    tile_info = TileInformation(
        array_shape=tuple(int(dim) for dim in data_shape),
        last_tile=last_tile,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=sample_id,
    )
    return tile_info
168
+
169
+
170
def compute_padding(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[tuple[int, int], ...]:
    """
    Calculate padding to ensure stitched data comes from the center of a tile.

    Padding is added to an array with shape `data_shape` so that when tiles are
    stitched together, the data used always comes from the center of a tile, even for
    tiles at the boundaries of the array.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together. The spatial axes
        are the trailing `len(tile_size)` entries, e.g. (C, (Z), Y, X) or
        (S, C, (Z), Y, X).
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of (int, int)
        A tuple specifying the padding to add in each spatial dimension, each element
        is a two element tuple of Python ints specifying the padding to add before
        and after the data. This can be used as the `pad_width` argument to
        `numpy.pad` (after prepending entries for any non-spatial axes).
    """
    # Accept any integer sequence; the arithmetic below is element-wise.
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    # Total extent spanned by the tile grid along each spatial axis.
    covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps

    # Half an overlap before, so the first stitched region sits at a tile center;
    # the rest of the uncovered extent goes after the data.
    pad_before = overlaps // 2
    pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before

    return tuple(
        (int(before), int(after)) for before, after in zip(pad_before, pad_after)
    )
205
+
206
+
207
def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
    """Calculate the number of tiles in a specific dimension.

    Tiles start every `tile_size - overlap` pixels, so the count is the ceiling of
    `axis_size / (tile_size - overlap)`.

    Parameters
    ----------
    axis_size : int
        The length of the data in a specific dimension.
    tile_size : int
        The length of the tiles in a specific dimension.
    overlap : int
        The tile overlap in a specific dimension.

    Returns
    -------
    int
        The number of tiles that fit in one dimension given the arguments.

    Raises
    ------
    ValueError
        If `overlap` is not smaller than `tile_size`, in which case the tiles would
        not advance along the axis.
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError(
            f"Tile size ({tile_size}) must be larger than the overlap ({overlap})."
        )
    # Exact integer ceiling division: unlike float division + np.ceil, this does
    # not lose precision for large sizes and works for numpy integer inputs.
    return int(-(-axis_size // step))
225
+
226
+
227
def total_n_tiles(
    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
) -> int:
    """Calculate the total number of tiles over all spatial dimensions.

    The spatial axes are taken to be the trailing `len(tile_size)` entries of
    `data_shape` and `overlaps`; the per-axis tile counts are multiplied together.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    int
        The total number of tiles over all dimensions.
    """
    n_spatial = len(tile_size)
    n_total = 1
    # Walk the spatial axes from the last one backwards.
    for offset in range(1, n_spatial + 1):
        n_total *= n_tiles_1d(
            data_shape[-offset], tile_size[-offset], overlaps[-offset]
        )
    return n_total
253
+
254
+
255
def compute_tile_grid_shape(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[int, ...]:
    """Calculate the number of tiles along each spatial dimension.

    The result describes the grid of tiles covering the data. The spatial axes are
    the trailing `len(tile_size)` entries of `data_shape` and `overlaps`.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of int
        The number of tiles in each direction, ((Z), Y, X).
    """
    n_spatial = len(tile_size)
    # Evaluate the axes from last to first, then flip back into ((Z), Y, X) order.
    counts_last_to_first = [
        n_tiles_1d(data_shape[-offset], tile_size[-offset], overlaps[-offset])
        for offset in range(1, n_spatial + 1)
    ]
    return tuple(reversed(counts_last_to_first))
@@ -0,0 +1,164 @@
1
+ """Tiled patching utilities."""
2
+
3
+ import itertools
4
+ from typing import Generator, List, Tuple, Union
5
+
6
+ import numpy as np
7
+
8
+ from careamics.config.tile_information import TileInformation
9
+
10
+
11
def _compute_crop_and_stitch_coords_1d(
    axis_size: int, tile_size: int, overlap: int
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
    """
    Compute the coordinates of each tile along an axis, given the overlap.

    Parameters
    ----------
    axis_size : int
        Length of the axis.
    tile_size : int
        Size of the tile for the given axis.
    overlap : int
        Size of the overlap for the given axis.

    Returns
    -------
    Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]
        Three lists with one (start, end) pair per tile:

        - crop coordinates: where to cut the tile from the axis,
        - stitch coordinates: where the tile's kept part goes in the axis,
        - overlap crop coordinates: which part of the tile to keep, relative to the
          start of the cropped tile.
    """
    # Compute the step between tiles
    step = tile_size - overlap
    half_overlap = overlap // 2
    crop_coords: List[Tuple[int, int]] = []
    stitch_coords: List[Tuple[int, int]] = []
    overlap_crop_coords: List[Tuple[int, int]] = []

    # Iterate over the axis with step
    for start in range(0, max(1, axis_size - overlap), step):
        end = start + tile_size

        # Check if the tile fits within the axis
        if end <= axis_size:
            crop_coords.append((start, end))
            reaches_axis_end = end >= axis_size

            # Pixel coordinates of the kept part of the tile in the image; the
            # first tile starts at 0 and the tile touching the end keeps the rest.
            stitch_coords.append(
                (
                    start + half_overlap if start > 0 else 0,
                    axis_size if reaches_axis_end else end - half_overlap,
                )
            )

            # Coordinates to crop the overlap from the prediction.
            overlap_crop_coords.append(
                (
                    half_overlap if start > 0 else 0,
                    tile_size if reaches_axis_end else tile_size - half_overlap,
                )
            )

        # If the tile does not fit within the axis, take a last tile aligned with
        # the end of the axis (or the whole axis if it is shorter than a tile).
        else:
            crop_start = max(0, axis_size - tile_size)
            crop_coords.append((crop_start, axis_size))

            # Stitch from where the previous tile stopped; with no previous tile,
            # the whole axis comes from this one.
            stitch_start = stitch_coords[-1][1] if stitch_coords else 0
            stitch_coords.append((stitch_start, axis_size))

            # Overlap crop expressed relative to the actual crop start, so it stays
            # inside the tile even when the tile is shorter than `tile_size`.
            overlap_crop_coords.append(
                (stitch_start - crop_start, axis_size - crop_start)
            )
            break
    return crop_coords, stitch_coords, overlap_crop_coords
80
+
81
+
82
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """Generate overlapping tiles covering every sample of the input array.

    Each sample `arr[s]` is tiled independently. For every tile the generator yields
    the tile together with a `TileInformation` holding:

    - the sample index,
    - whether this is the sample's last tile,
    - the coordinates of the overlap crop,
    - the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
    where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each dimension, of length 2 or 3.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    n_spatial_dims = len(tile_size)

    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]

        # One (crop, stitch, overlap-crop) triple of coordinate lists per spatial
        # axis (axis 0 of the sample is C). For example, an axis of size 35 with
        # tile size 32 and overlap 24 gives
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        per_axis_coords = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[axis + 1], tile_size[axis], overlaps[axis]
            )
            for axis in range(n_spatial_dims)
        ]

        # Regroup by coordinate type: each holds one list of pairs per axis.
        crops_per_axis, stitches_per_axis, overlap_crops_per_axis = zip(
            *per_axis_coords
        )

        n_tiles = int(np.prod([len(axis_crops) for axis_crops in crops_per_axis]))

        # The Cartesian product over axes enumerates every tile of the grid.
        tile_coords = zip(
            itertools.product(*crops_per_axis),
            itertools.product(*stitches_per_axis),
            itertools.product(*overlap_crops_per_axis),
        )
        for tile_idx, (crop, stitch, overlap_crop) in enumerate(tile_coords):
            spatial_slices = [slice(start, stop) for start, stop in crop]
            tile: np.ndarray = sample[(..., *spatial_slices)]  # type: ignore

            yield tile, TileInformation(
                array_shape=sample.shape,
                last_tile=tile_idx == n_tiles - 1,
                overlap_crop_coords=overlap_crop,
                stitch_coords=stitch,
                sample_id=sample_idx,
            )