careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
"""Module containing tiling manager class."""
|
|
2
|
+
|
|
3
|
+
# # TODO: remove this file, left as a reference for now.
|
|
4
|
+
|
|
5
|
+
# from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
# import numpy as np
|
|
8
|
+
# from numpy.typing import NDArray
|
|
9
|
+
|
|
10
|
+
# from careamics.config.tile_information import TileInformation
|
|
11
|
+
# from careamics.config.validators import check_axes_validity
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# def calculate_padding(
|
|
15
|
+
# patch_start_location: NDArray,
|
|
16
|
+
# patch_size: NDArray,
|
|
17
|
+
# data_shape: NDArray,
|
|
18
|
+
# ) -> NDArray:
|
|
19
|
+
# patch_end_location = patch_start_location + patch_size
|
|
20
|
+
|
|
21
|
+
# pad_before = np.zeros_like(patch_start_location)
|
|
22
|
+
# start_out_of_bounds = patch_start_location < 0
|
|
23
|
+
# pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds]
|
|
24
|
+
|
|
25
|
+
# pad_after = np.zeros_like(patch_start_location)
|
|
26
|
+
# end_out_of_bounds = patch_end_location > data_shape
|
|
27
|
+
# pad_after[end_out_of_bounds] = (
|
|
28
|
+
# patch_end_location - data_shape
|
|
29
|
+
# )[end_out_of_bounds]
|
|
30
|
+
|
|
31
|
+
# return np.stack([pad_before, pad_after], axis=1)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# def extract_tile(
|
|
35
|
+
# img: np.ndarray,
|
|
36
|
+
# grid_start_loc: tuple[int, ...],
|
|
37
|
+
# patch_size: tuple[int, ...],
|
|
38
|
+
# overlap: tuple[int, ...],
|
|
39
|
+
# padding: bool,
|
|
40
|
+
# padding_kwargs: Optional[dict[str, Any]] = None,
|
|
41
|
+
# ) -> NDArray:
|
|
42
|
+
# if padding_kwargs is None:
|
|
43
|
+
# padding_kwargs = {}
|
|
44
|
+
|
|
45
|
+
# data_shape = img.shape
|
|
46
|
+
# patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2
|
|
47
|
+
# crop_slices = tuple(
|
|
48
|
+
# slice(max(0, start), min(start + size, dim_shape))
|
|
49
|
+
# for start, size, dim_shape in zip(patch_start_loc, patch_size, data_shape)
|
|
50
|
+
# )
|
|
51
|
+
# crop = img[crop_slices]
|
|
52
|
+
# if padding:
|
|
53
|
+
# pad = calculate_padding(
|
|
54
|
+
# patch_start_location=patch_start_loc,
|
|
55
|
+
# patch_size=patch_size,
|
|
56
|
+
# data_shape=data_shape,
|
|
57
|
+
# )
|
|
58
|
+
# crop = np.pad(crop, pad, **padding_kwargs)
|
|
59
|
+
|
|
60
|
+
# return crop
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# class TilingManager:
|
|
64
|
+
|
|
65
|
+
# def __init__(
|
|
66
|
+
# self,
|
|
67
|
+
# data_shape: tuple[int, ...],
|
|
68
|
+
# tile_size: tuple[int, ...],
|
|
69
|
+
# overlaps: tuple[int, ...],
|
|
70
|
+
# trim_boundary: tuple[int, ...],
|
|
71
|
+
# ):
|
|
72
|
+
# # --- validation
|
|
73
|
+
# if len(data_shape) != len(tile_size):
|
|
74
|
+
# raise ValueError(
|
|
75
|
+
# f"Data shape:{data_shape} and tile size:{tile_size} must have the "
|
|
76
|
+
# "same dimension"
|
|
77
|
+
# )
|
|
78
|
+
# if len(data_shape) != len(overlaps):
|
|
79
|
+
# raise ValueError(
|
|
80
|
+
# f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the "
|
|
81
|
+
# "same dimension"
|
|
82
|
+
# )
|
|
83
|
+
# # overlaps = np.array(tile_size) - np.array(grid_shape)
|
|
84
|
+
# if (np.array(overlaps) < 0).any():
|
|
85
|
+
# raise ValueError(
|
|
86
|
+
# "Tile overlap must be positive or zero in all dimension."
|
|
87
|
+
# )
|
|
88
|
+
# if ((np.array(overlaps) % 2) != 0).any():
|
|
89
|
+
# # TODO: currently not required by CAREamics tiling,
|
|
90
|
+
# # -> because floor divide is used.
|
|
91
|
+
# raise ValueError("Tile overlaps must be even.")
|
|
92
|
+
|
|
93
|
+
# # initialize attributes
|
|
94
|
+
# self.data_shape = data_shape
|
|
95
|
+
# self.overlaps = overlaps
|
|
96
|
+
# self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps))
|
|
97
|
+
# self.patch_shape = tile_size
|
|
98
|
+
# self.trim_boundary = trim_boundary
|
|
99
|
+
|
|
100
|
+
# def compute_tile_info(self, index: int, axes: str):
|
|
101
|
+
|
|
102
|
+
# # TODO: better axis validation, data should already be in the form SC(Z)YX
|
|
103
|
+
|
|
104
|
+
# # validate axes
|
|
105
|
+
# check_axes_validity(axes)
|
|
106
|
+
# # z will be -1 if not present
|
|
107
|
+
# spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")]
|
|
108
|
+
|
|
109
|
+
# # convert to numpy for convenience
|
|
110
|
+
# data_shape = np.array(self.data_shape)
|
|
111
|
+
# patch_shape = np.array(self.patch_shape)
|
|
112
|
+
|
|
113
|
+
# # --- calculate stitch coords
|
|
114
|
+
# stitch_coords_start = np.array(self.get_location_from_dataset_idx(index))
|
|
115
|
+
# stitch_coords_end = stitch_coords_start + np.array(self.grid_shape)
|
|
116
|
+
|
|
117
|
+
# # --- patch coords
|
|
118
|
+
# patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2
|
|
119
|
+
# patch_coords_end = patch_coords_start + patch_shape
|
|
120
|
+
|
|
121
|
+
# # --- replace out of bounds indices
|
|
122
|
+
|
|
123
|
+
# out_of_lower_bound = stitch_coords_start < 0
|
|
124
|
+
# out_of_upper_bound = stitch_coords_end > data_shape
|
|
125
|
+
|
|
126
|
+
# stitch_coords_start[out_of_lower_bound] = 0
|
|
127
|
+
# stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]
|
|
128
|
+
|
|
129
|
+
# # --- calculate overlap crop coords
|
|
130
|
+
# overlap_crop_coords_start = stitch_coords_start - patch_coords_start
|
|
131
|
+
# overlap_crop_coords_end = overlap_crop_coords_start + (
|
|
132
|
+
# stitch_coords_end - stitch_coords_start
|
|
133
|
+
# )
|
|
134
|
+
|
|
135
|
+
# # --- combine start and end
|
|
136
|
+
# stitch_coords = tuple(
|
|
137
|
+
# (stitch_coords_start[axis], stitch_coords_end[axis])
|
|
138
|
+
# for axis in spatial_axes
|
|
139
|
+
# if axis != -1
|
|
140
|
+
# )
|
|
141
|
+
# overlap_crop_coords = tuple(
|
|
142
|
+
# (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis])
|
|
143
|
+
# for axis in spatial_axes
|
|
144
|
+
# if axis != -1
|
|
145
|
+
# )
|
|
146
|
+
|
|
147
|
+
# channel_axis = axes.find("C")
|
|
148
|
+
# array_shape_processed = tuple(
|
|
149
|
+
# data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1
|
|
150
|
+
# )
|
|
151
|
+
|
|
152
|
+
# tile_info = TileInformation(
|
|
153
|
+
# array_shape=array_shape_processed,
|
|
154
|
+
# last_tile=index == self.total_grid_count() - 1,
|
|
155
|
+
# overlap_crop_coords=overlap_crop_coords,
|
|
156
|
+
# stitch_coords=stitch_coords,
|
|
157
|
+
# sample_id=0, # TODO: in iterable dataset this is also always 0 pretty sure
|
|
158
|
+
# )
|
|
159
|
+
# return tile_info
|
|
160
|
+
|
|
161
|
+
# def patch_offset(self):
|
|
162
|
+
# return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
|
|
163
|
+
|
|
164
|
+
# def get_individual_dim_grid_count(self, dim: int):
|
|
165
|
+
# """
|
|
166
|
+
# Returns the number of the grid in the specified dimension, ignoring all other
|
|
167
|
+
# dimensions.
|
|
168
|
+
# """
|
|
169
|
+
# assert dim < len(
|
|
170
|
+
# self.data_shape
|
|
171
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
172
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
173
|
+
|
|
174
|
+
# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
|
|
175
|
+
# return self.data_shape[dim]
|
|
176
|
+
# elif self.trim_boundary is False:
|
|
177
|
+
# return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
|
|
178
|
+
# else:
|
|
179
|
+
# excess_size = self.patch_shape[dim] - self.grid_shape[dim]
|
|
180
|
+
# return int(
|
|
181
|
+
# np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
|
|
182
|
+
# )
|
|
183
|
+
|
|
184
|
+
# def total_grid_count(self):
|
|
185
|
+
# """
|
|
186
|
+
# Returns the total number of grids in the dataset.
|
|
187
|
+
# """
|
|
188
|
+
# return self.grid_count(0) * self.get_individual_dim_grid_count(0)
|
|
189
|
+
|
|
190
|
+
# def grid_count(self, dim: int):
|
|
191
|
+
# """
|
|
192
|
+
# Returns the total number of grids for one value in the specified dimension.
|
|
193
|
+
# """
|
|
194
|
+
# assert dim < len(
|
|
195
|
+
# self.data_shape
|
|
196
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
197
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
198
|
+
# if dim == len(self.data_shape) - 1:
|
|
199
|
+
# return 1
|
|
200
|
+
|
|
201
|
+
# return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
|
|
202
|
+
|
|
203
|
+
# def get_grid_index(self, dim: int, coordinate: int):
|
|
204
|
+
# """
|
|
205
|
+
# Returns the index of the grid in the specified dimension.
|
|
206
|
+
# """
|
|
207
|
+
# assert dim < len(
|
|
208
|
+
# self.data_shape
|
|
209
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
210
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
211
|
+
# assert (
|
|
212
|
+
# coordinate < self.data_shape[dim]
|
|
213
|
+
# ), (
|
|
214
|
+
# f"Coordinate {coordinate} is out of bounds for data "
|
|
215
|
+
# f"shape {self.data_shape}"
|
|
216
|
+
# )
|
|
217
|
+
# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
|
|
218
|
+
# return coordinate
|
|
219
|
+
# elif self.trim_boundary is False:
|
|
220
|
+
# return np.floor(coordinate / self.grid_shape[dim])
|
|
221
|
+
# else:
|
|
222
|
+
# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
|
|
223
|
+
# # can be <0 if coordinate is in [0,grid_shape[dim]]
|
|
224
|
+
# return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
|
|
225
|
+
|
|
226
|
+
# def dataset_idx_from_grid_idx(self, grid_idx: tuple):
|
|
227
|
+
# """
|
|
228
|
+
# Returns the index of the grid in the dataset.
|
|
229
|
+
# """
|
|
230
|
+
# assert len(grid_idx) == len(
|
|
231
|
+
# self.data_shape
|
|
232
|
+
# ), (
|
|
233
|
+
# f"Dimension indices {grid_idx} must have the same dimension as data "
|
|
234
|
+
# f"shape {self.data_shape}"
|
|
235
|
+
# )
|
|
236
|
+
# index = 0
|
|
237
|
+
# for dim in range(len(grid_idx)):
|
|
238
|
+
# index += grid_idx[dim] * self.grid_count(dim)
|
|
239
|
+
# return index
|
|
240
|
+
|
|
241
|
+
# def get_patch_location_from_dataset_idx(self, dataset_idx: int):
|
|
242
|
+
# """
|
|
243
|
+
# Returns the patch location of the grid in the dataset.
|
|
244
|
+
# """
|
|
245
|
+
# location = self.get_location_from_dataset_idx(dataset_idx)
|
|
246
|
+
# offset = self.patch_offset()
|
|
247
|
+
# return tuple(np.array(location) - np.array(offset))
|
|
248
|
+
|
|
249
|
+
# def get_dataset_idx_from_grid_location(self, location: tuple):
|
|
250
|
+
# assert len(location) == len(
|
|
251
|
+
# self.data_shape
|
|
252
|
+
# ), (
|
|
253
|
+
# f"Location {location} must have the same dimension as data shape "
|
|
254
|
+
# f"{self.data_shape}"
|
|
255
|
+
# )
|
|
256
|
+
# grid_idx = [
|
|
257
|
+
# self.get_grid_index(dim, location[dim]) for dim in range(len(location))
|
|
258
|
+
# ]
|
|
259
|
+
# return self.dataset_idx_from_grid_idx(tuple(grid_idx))
|
|
260
|
+
|
|
261
|
+
# def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
|
|
262
|
+
# """
|
|
263
|
+
# Returns the grid-start coordinate of the grid in the specified dimension.
|
|
264
|
+
# """
|
|
265
|
+
# assert dim < len(
|
|
266
|
+
# self.data_shape
|
|
267
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
268
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
269
|
+
# assert dim_index < self.get_individual_dim_grid_count(
|
|
270
|
+
# dim
|
|
271
|
+
# ), (
|
|
272
|
+
# f"Dimension index {dim_index} is out of bounds for data shape "
|
|
273
|
+
# f"{self.data_shape}"
|
|
274
|
+
# )
|
|
275
|
+
|
|
276
|
+
# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
|
|
277
|
+
# return dim_index
|
|
278
|
+
# elif self.trim_boundary is False:
|
|
279
|
+
# return dim_index * self.grid_shape[dim]
|
|
280
|
+
# else:
|
|
281
|
+
# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
|
|
282
|
+
# return dim_index * self.grid_shape[dim] + excess_size
|
|
283
|
+
|
|
284
|
+
# def get_location_from_dataset_idx(self, dataset_idx: int):
|
|
285
|
+
# grid_idx = []
|
|
286
|
+
# for dim in range(len(self.data_shape)):
|
|
287
|
+
# grid_idx.append(dataset_idx // self.grid_count(dim))
|
|
288
|
+
# dataset_idx = dataset_idx % self.grid_count(dim)
|
|
289
|
+
# location = [
|
|
290
|
+
# self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
|
|
291
|
+
# for dim in range(len(self.data_shape))
|
|
292
|
+
# ]
|
|
293
|
+
# return tuple(location)
|
|
294
|
+
|
|
295
|
+
# def on_boundary(self, dataset_idx: int, dim: int):
|
|
296
|
+
# """
|
|
297
|
+
# Returns True if the grid is on the boundary in the specified dimension.
|
|
298
|
+
# """
|
|
299
|
+
# assert dim < len(
|
|
300
|
+
# self.data_shape
|
|
301
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
302
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
303
|
+
|
|
304
|
+
# if dim > 0:
|
|
305
|
+
# dataset_idx = dataset_idx % self.grid_count(dim - 1)
|
|
306
|
+
|
|
307
|
+
# dim_index = dataset_idx // self.grid_count(dim)
|
|
308
|
+
# return (
|
|
309
|
+
# dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
|
|
310
|
+
# )
|
|
311
|
+
|
|
312
|
+
# def next_grid_along_dim(self, dataset_idx: int, dim: int):
|
|
313
|
+
# """
|
|
314
|
+
# Returns the index of the grid in the specified dimension in the specified "
|
|
315
|
+
# "direction.
|
|
316
|
+
# """
|
|
317
|
+
# assert dim < len(
|
|
318
|
+
# self.data_shape
|
|
319
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
320
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
321
|
+
# new_idx = dataset_idx + self.grid_count(dim)
|
|
322
|
+
# if new_idx >= self.total_grid_count():
|
|
323
|
+
# return None
|
|
324
|
+
# return new_idx
|
|
325
|
+
|
|
326
|
+
# def prev_grid_along_dim(self, dataset_idx: int, dim: int):
|
|
327
|
+
# """
|
|
328
|
+
# Returns the index of the grid in the specified dimension in the specified "
|
|
329
|
+
# "direction.
|
|
330
|
+
# """
|
|
331
|
+
# assert dim < len(
|
|
332
|
+
# self.data_shape
|
|
333
|
+
# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
|
|
334
|
+
# assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
335
|
+
# new_idx = dataset_idx - self.grid_count(dim)
|
|
336
|
+
# if new_idx < 0:
|
|
337
|
+
# return None
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
# if __name__ == "__main__":
|
|
341
|
+
# data_shape = (1, 1, 103, 103, 2)
|
|
342
|
+
# grid_shape = (1, 1, 16, 16, 2)
|
|
343
|
+
# patch_shape = (1, 1, 32, 32, 2)
|
|
344
|
+
# overlap = tuple(np.array(patch_shape) - np.array(grid_shape))
|
|
345
|
+
|
|
346
|
+
# trim_boundary = False
|
|
347
|
+
# manager = TilingManager(
|
|
348
|
+
# data_shape=data_shape,
|
|
349
|
+
# tile_size=patch_shape,
|
|
350
|
+
# overlaps=overlap,
|
|
351
|
+
# trim_boundary=trim_boundary,
|
|
352
|
+
# )
|
|
353
|
+
# gc = manager.total_grid_count()
|
|
354
|
+
# print("Grid count", gc)
|
|
355
|
+
# for i in range(gc):
|
|
356
|
+
# loc = manager.get_location_from_dataset_idx(i)
|
|
357
|
+
# print(i, loc)
|
|
358
|
+
# inferred_i = manager.get_dataset_idx_from_grid_location(loc)
|
|
359
|
+
# assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"
|
|
360
|
+
|
|
361
|
+
# for i in range(5):
|
|
362
|
+
# print(manager.on_boundary(40, i))
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Literal, Tuple, Union, overload
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from ..config.tile_information import TileInformation
|
|
9
|
+
from .stitch_prediction import stitch_prediction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def convert_outputs(predictions: List[Any], tiled: bool) -> list[NDArray]:
    """
    Turn raw Lightning prediction outputs into a list of arrays.

    Batches are first merged with `combine_batches`. When the predictions are
    tiled, the resulting tiles are additionally stitched back into full images
    with `stitch_prediction`.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray
        List of arrays with the axes SC(Z)YX. An empty `predictions` list is
        returned unchanged.
    """
    # Nothing to combine or stitch: hand the empty input straight back.
    if len(predictions) == 0:
        return predictions

    # Untiled predictions only need their batches merged.
    if not tiled:
        merged_arrays = combine_batches(predictions, tiled)
        return merged_arrays

    # Tiled predictions come back as (tiles, tile_infos), which are exactly the
    # two positional arguments expected by `stitch_prediction`.
    tiles_and_infos = combine_batches(predictions, tiled)
    stitched_images = stitch_prediction(*tiles_and_infos)
    return stitched_images
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# for mypy
|
|
45
|
+
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[True]
) -> Tuple[List[NDArray], List[TileInformation]]: ...


# for mypy: untiled input yields only the list of arrays
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[False]
) -> List[NDArray]: ...


# for mypy: fallback when `tiled` is only known to be a bool
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...


def combine_batches(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
    """
    Merge batched predictions into per-sample outputs.

    The work is delegated to `_combine_tiled_batches` for tiled predictions and
    to `_combine_array_batches` otherwise.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches.
    """
    # Plain array batches carry no tile information to collect.
    if not tiled:
        return _combine_array_batches(predictions)

    return _combine_tiled_batches(predictions)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _combine_tiled_batches(
    predictions: List[Tuple[NDArray, List[TileInformation]]]
) -> Tuple[List[NDArray], List[TileInformation]]:
    """
    Merge the batches of a tiled prediction.

    Parameters
    ----------
    predictions : list of (numpy.ndarray, list of TileInformation)
        Predictions that are output from `Trainer.predict`. For tiled batches, this is
        a list of tuples. The first element of the tuples is the prediction output of
        tiles with dimension (B, C, (Z), Y, X), where B is batch size. The second
        element of the tuples is a list of TileInformation objects of length B.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches: the individual tiles and the flattened tile information,
        both in batch order.
    """
    batch_arrays: List[NDArray] = []
    tile_infos: List[TileInformation] = []

    # Walk the batches once, separating arrays from their tile information and
    # flattening the per-batch information lists into a single list.
    for batch_array, batch_infos in predictions:
        batch_arrays.append(batch_array)
        tile_infos.extend(batch_infos)

    prediction_tiles: List[NDArray] = _combine_array_batches(batch_arrays)
    return prediction_tiles, tile_infos
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
    """
    Combine batches of arrays into a list of single-sample arrays.

    The batches are concatenated along the first (batch) axis and then split
    into one array per sample, each keeping a leading singleton axis.

    Parameters
    ----------
    predictions : list
        Prediction arrays that are output from `Trainer.predict`. A list of arrays that
        have dimensions (B, C, (Z), Y, X), where B is batch size.

    Returns
    -------
    list of numpy.ndarray
        A list of arrays with dimensions (1, C, (Z), Y, X). The list is empty if
        there are no batches or if the batches contain no samples.
    """
    # `np.concatenate` raises on an empty sequence; no batches means no samples.
    if len(predictions) == 0:
        return []

    prediction_concat: NDArray = np.concatenate(predictions, axis=0)
    n_samples = prediction_concat.shape[0]

    # `np.split` rejects zero sections, which happens when every batch is empty.
    if n_samples == 0:
        return []

    # One section per sample keeps the leading axis with length 1.
    return np.split(prediction_concat, n_samples, axis=0)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Prediction utility functions."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from typing import List, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.tile_information import TileInformation
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO: why not allow input and output of torch.tensor ?
|
|
13
|
+
def stitch_prediction(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> List[np.ndarray]:
    """
    Stitch tiles back together to form a full image(s).

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
    singleton dimension.

    The tiles are grouped per image using the `last_tile` flag of each
    `TileInformation`, and each group is stitched with
    `stitch_prediction_single`. Tiles located after the final flagged tile are
    not part of any group and do not appear in the output.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain tiles
        from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s).
    """
    # Indices of the tiles that close an image; each one ends a group of tiles.
    is_last = [info.last_tile for info in tile_infos]
    group_ends = np.where(is_last)[0]

    stitched_images = []
    group_start = None  # the first group starts at the beginning of the lists
    for group_end in group_ends:
        group_stop = group_end + 1  # slices are exclusive at the upper bound
        stitched_images.append(
            stitch_prediction_single(
                tiles[group_start:group_stop], tile_infos[group_start:group_stop]
            )
        )
        # The next image begins right after the tile that closed this one.
        group_start = group_stop

    return stitched_images
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def stitch_prediction_single(
    tiles: List[NDArray],
    tile_infos: List[TileInformation],
) -> NDArray:
    """
    Stitch the tiles of one image back into the full image.

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
    singleton dimension. Each tile is cropped with its `overlap_crop_coords` and
    written into the output at its `stitch_coords`.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image, with dimensions SC(Z)YX and dtype float32.

    Raises
    ------
    ValueError
        If the array shape stored in the first tile information does not have
        3 or 4 dimensions.
    """
    # TODO: this is hacky... need a better way to deal with when input channels and
    # target channels do not match
    full_shape = tile_infos[0].array_shape
    n_dims = len(full_shape)
    if n_dims not in (3, 4):
        # Note pretty sure this is unreachable because array shape is already
        # validated by TileInformation
        raise ValueError(f"Unsupported number of output dimension {n_dims}")

    # 3 dimensions => 2 spatial dimensions, 4 dimensions => 3 spatial dimensions,
    # so counting `n_dims` axes back from the end of a tile gives its channel axis.
    tile_channels = tiles[0].shape[-n_dims]

    # Whole array size with an added S dim, using the channel count of the tiles
    # rather than the one recorded in `array_shape`.
    predicted_image = np.zeros((1, tile_channels, *full_shape[1:]), dtype=np.float32)

    for tile, tile_info in zip(tiles, tile_infos):
        # Region of the predicted tile that is kept once the overlap is removed.
        crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
            ...,
            *(slice(bounds[0], bounds[1]) for bounds in tile_info.overlap_crop_coords),
        )
        # Region of the full image that this tile fills.
        stitch_slices = (
            ...,
            *(slice(bounds[0], bounds[1]) for bounds in tile_info.stitch_coords),
        )
        predicted_image[stitch_slices] = tile[crop_slices].astype(np.float32)

    return predicted_image
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Transforms that are used to augment the data."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"get_all_transforms",
|
|
5
|
+
"N2VManipulate",
|
|
6
|
+
"XYFlip",
|
|
7
|
+
"XYRandomRotate90",
|
|
8
|
+
"ImageRestorationTTA",
|
|
9
|
+
"Denormalize",
|
|
10
|
+
"Normalize",
|
|
11
|
+
"Compose",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from .compose import Compose, get_all_transforms
|
|
16
|
+
from .n2v_manipulate import N2VManipulate
|
|
17
|
+
from .normalize import Denormalize, Normalize
|
|
18
|
+
from .tta import ImageRestorationTTA
|
|
19
|
+
from .xy_flip import XYFlip
|
|
20
|
+
from .xy_random_rotate90 import XYRandomRotate90
|