careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (64) hide show
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
@@ -219,7 +219,7 @@ class StableLogVar:
219
219
  self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
220
220
  ):
221
221
  """
222
- Contructor.
222
+ Constructor.
223
223
 
224
224
  Parameters
225
225
  ----------
@@ -295,7 +295,7 @@ class StableMean:
295
295
 
296
296
  def allow_numpy(func):
297
297
  """
298
- All optional arguements are passed as is. positional arguments are checked. if they are numpy array,
298
+ All optional arguments are passed as is. positional arguments are checked. if they are numpy array,
299
299
  they are converted to torch Tensor.
300
300
  """
301
301
 
@@ -4,20 +4,34 @@ Model factory.
4
4
  Model creation factory functions.
5
5
  """
6
6
 
7
- from typing import Union
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING, Union
8
10
 
9
11
  import torch
10
12
 
11
- from ..config.architectures import CustomModel, UNetModel, VAEModel, get_custom_model
12
- from ..config.support import SupportedArchitecture
13
- from ..utils import get_logger
14
- from .unet import UNet
13
+ from careamics.config.architectures import (
14
+ CustomModel,
15
+ get_custom_model,
16
+ )
17
+ from careamics.config.support import SupportedArchitecture
18
+ from careamics.models.lvae import LadderVAE as LVAE
19
+ from careamics.models.unet import UNet
20
+ from careamics.utils import get_logger
21
+
22
+ if TYPE_CHECKING:
23
+ from careamics.config.architectures import (
24
+ CustomModel,
25
+ LVAEModel,
26
+ UNetModel,
27
+ )
28
+
15
29
 
16
30
  logger = get_logger(__name__)
17
31
 
18
32
 
19
33
  def model_factory(
20
- model_configuration: Union[UNetModel, VAEModel, CustomModel]
34
+ model_configuration: Union[UNetModel, LVAEModel, CustomModel],
21
35
  ) -> torch.nn.Module:
22
36
  """
23
37
  Deep learning model factory.
@@ -41,10 +55,11 @@ def model_factory(
41
55
  """
42
56
  if model_configuration.architecture == SupportedArchitecture.UNET:
43
57
  return UNet(**model_configuration.model_dump())
58
+ elif model_configuration.architecture == SupportedArchitecture.LVAE:
59
+ return LVAE(**model_configuration.model_dump())
44
60
  elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
45
61
  assert isinstance(model_configuration, CustomModel)
46
62
  model = get_custom_model(model_configuration.name)
47
-
48
63
  return model(**model_configuration.model_dump())
49
64
  else:
50
65
  raise NotImplementedError(
@@ -0,0 +1,158 @@
1
+ """Module containing pytorch implementations for obtaining predictions from an LVAE."""
2
+
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+
7
+ from careamics.models.lvae import LadderVAE as LVAE
8
+ from careamics.models.lvae.likelihoods import LikelihoodModule
9
+
10
+ # TODO: convert these functions to lightning module `predict_step`
11
+ # -> mmse_count will have to be an instance attribute?
12
+
13
+
14
+ # This function is needed because the output of the datasets (input here) can include
15
+ # auxiliary items, such as the TileInformation. This function allows for easier reuse
16
+ # between lvae_predict_single_sample and lvae_predict_mmse.
17
def lvae_predict_single_sample(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Generate a single sample prediction from an LVAE model, for a given input.

    The model is put into evaluation mode and run once without gradient
    tracking. The likelihood object then splits the raw network output into the
    sample prediction and, optionally, its log-variance.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model. Note that `model.eval()` is called on it and the
        previous train/eval mode is not restored afterwards.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class. Its `get_mean_lv` method is used to
        split the network output.
    input : torch.tensor
        Input to generate prediction for. Expected shape is (S, C, Y, X).

    Returns
    -------
    tuple of (torch.tensor, optional torch.tensor)
        The first element is the sample prediction, and the second element is the
        log-variance. The log-variance will be None if `model.predict_logvar is None`.
    """
    # Not in the original predict code: eval() affects batch-norm and dropout
    # layers.
    model.eval()
    with torch.no_grad():
        # The model returns a pair; the second item is the top-down data dict,
        # which is not needed for prediction.
        network_output, _top_down_data = model(input)

    # Presently, get_mean_lv just splits the output in two when
    # predict_logvar=True, and optionally clips the log-variance when
    # logvar_lowerbound is not None.
    # TODO: consider refactoring to remove use of the likelihood object
    prediction, log_variance = likelihood_obj.get_mean_lv(network_output)

    # TODO: output denormalization using target stats that will be saved in the
    # data config -> probably not needed, was seen in unrelated code.
    return prediction, log_variance
54
+
55
+
56
def lvae_predict_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any, ...],
) -> tuple[tuple[Any, ...], Optional[tuple[Any, ...]]]:
    """
    Generate a single sample prediction for a batch that may carry auxiliary data.

    The model input is separated from any auxiliary items (such as
    `TileInformation`). A single sample prediction is generated for the model
    input, and the auxiliary items are re-attached to the outputs.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.tensor | tuple of (torch.tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such
        as `TileInformation`, in which case the model input is always the first
        item of the tuple. A bare tensor is treated as a model input without
        auxiliary data. Expected shape of the model input is (S, C, Y, X).

    Returns
    -------
    tuple of ((torch.tensor, Any, ...), optional tuple of (torch.tensor, Any, ...))
        The first element is the sample prediction, and the second element is the
        log-variance. The log-variance will be None if `model.predict_logvar is None`.
        Any auxiliary data included in the input is also included with both the
        sample prediction and the log-variance.
    """
    x: torch.Tensor
    aux: list[Any]
    if isinstance(input, torch.Tensor):
        # A bare tensor must not be star-unpacked: `x, *aux = tensor` iterates
        # along the first (S) dimension. That would take only the first sample as
        # the model input and misreport the others as auxiliary data.
        x, aux = input, []
    else:
        x, *aux = input

    sample_prediction, log_var = lvae_predict_single_sample(
        model=model, likelihood_obj=likelihood_obj, input=x
    )

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (sample_prediction, *aux), log_var_output
94
+
95
+
96
def lvae_predict_mmse_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any, ...],
    mmse_count: int,
) -> tuple[tuple[Any, ...], tuple[Any, ...], Optional[tuple[Any, ...]]]:
    """
    Generate the MMSE (minimum mean squared error) prediction, for a given input.

    This is calculated from the mean of multiple single sample predictions.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.tensor | tuple of (torch.tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such
        as `TileInformation`, in which case the model input is always the first
        item of the tuple. A bare tensor is treated as a model input without
        auxiliary data. Expected shape of the model input is (S, C, Y, X).
    mmse_count : int
        Number of samples to generate to calculate MMSE (minimum mean squared error).

    Returns
    -------
    tuple of (tuple of (torch.Tensor[Any], Any, ...))
        A tuple of 3 elements:

        - The first element contains the MMSE prediction.
        - The second contains the standard deviation of the samples used to
          create the MMSE prediction. This is all zeros when `mmse_count == 1`.
        - The last element contains the log-variance of the likelihood for the
          first sample. This is `None` if `likelihood.predict_logvar` is `None`.

        Any auxiliary data included in the input is also included with each of
        the MMSE prediction, the standard deviation, and the log-variance.

    Raises
    ------
    ValueError
        If `mmse_count` is not greater than zero.
    """
    if mmse_count <= 0:
        raise ValueError("MMSE count must be greater than zero.")

    x: torch.Tensor
    aux: list[Any]
    if isinstance(input, torch.Tensor):
        # A bare tensor must not be star-unpacked: `x, *aux = tensor` iterates
        # along the first (S) dimension and would split the batch.
        x, aux = input, []
    else:
        x, *aux = input

    input_shape = x.shape
    output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
    log_var: Optional[torch.Tensor] = None
    # Pre-declare an empty array to fill with individual sample predictions.
    # NOTE(review): torch.zeros without `device`/`dtype` allocates on torch's
    # default device (normally the CPU) with the default dtype. The MMSE outputs
    # therefore do not follow the model's device. This is kept unchanged so
    # downstream consumers see the same output placement.
    sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
    for mmse_idx in range(mmse_count):
        sample_prediction, lv = lvae_predict_single_sample(
            model=model, likelihood_obj=likelihood_obj, input=x
        )
        # only keep the log variance of the first sample prediction
        if mmse_idx == 0:
            log_var = lv

        # store sample predictions
        sample_predictions[mmse_idx, ...] = sample_prediction

    mmse_prediction = torch.mean(sample_predictions, dim=0)
    if mmse_count > 1:
        mmse_prediction_std = torch.std(sample_predictions, dim=0)
    else:
        # The unbiased std of a single sample divides by zero (NaN plus a
        # degrees-of-freedom warning). A single prediction has no spread.
        mmse_prediction_std = torch.zeros_like(mmse_prediction)

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (mmse_prediction, *aux), (mmse_prediction_std, *aux), log_var_output
@@ -0,0 +1,362 @@
1
+ """Module containing tiling manager class."""
2
+
3
+ # # TODO: remove this file, left as a reference for now.
4
+
5
+ # from typing import Any, Optional
6
+
7
+ # import numpy as np
8
+ # from numpy.typing import NDArray
9
+
10
+ # from careamics.config.tile_information import TileInformation
11
+ # from careamics.config.validators import check_axes_validity
12
+
13
+
14
+ # def calculate_padding(
15
+ # patch_start_location: NDArray,
16
+ # patch_size: NDArray,
17
+ # data_shape: NDArray,
18
+ # ) -> NDArray:
19
+ # patch_end_location = patch_start_location + patch_size
20
+
21
+ # pad_before = np.zeros_like(patch_start_location)
22
+ # start_out_of_bounds = patch_start_location < 0
23
+ # pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds]
24
+
25
+ # pad_after = np.zeros_like(patch_start_location)
26
+ # end_out_of_bounds = patch_end_location > data_shape
27
+ # pad_after[end_out_of_bounds] = (
28
+ # patch_end_location - data_shape
29
+ # )[end_out_of_bounds]
30
+
31
+ # return np.stack([pad_before, pad_after], axis=1)
32
+
33
+
34
+ # def extract_tile(
35
+ # img: np.ndarray,
36
+ # grid_start_loc: tuple[int, ...],
37
+ # patch_size: tuple[int, ...],
38
+ # overlap: tuple[int, ...],
39
+ # padding: bool,
40
+ # padding_kwargs: Optional[dict[str, Any]] = None,
41
+ # ) -> NDArray:
42
+ # if padding_kwargs is None:
43
+ # padding_kwargs = {}
44
+
45
+ # data_shape = img.shape
46
+ # patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2
47
+ # crop_slices = tuple(
48
+ # slice(max(0, start), min(start + size, dim_shape))
49
+ # for start, size, dim_shape in zip(patch_start_loc, patch_size, data_shape)
50
+ # )
51
+ # crop = img[crop_slices]
52
+ # if padding:
53
+ # pad = calculate_padding(
54
+ # patch_start_location=patch_start_loc,
55
+ # patch_size=patch_size,
56
+ # data_shape=data_shape,
57
+ # )
58
+ # crop = np.pad(crop, pad, **padding_kwargs)
59
+
60
+ # return crop
61
+
62
+
63
+ # class TilingManager:
64
+
65
+ # def __init__(
66
+ # self,
67
+ # data_shape: tuple[int, ...],
68
+ # tile_size: tuple[int, ...],
69
+ # overlaps: tuple[int, ...],
70
+ # trim_boundary: tuple[int, ...],
71
+ # ):
72
+ # # --- validation
73
+ # if len(data_shape) != len(tile_size):
74
+ # raise ValueError(
75
+ # f"Data shape:{data_shape} and tile size:{tile_size} must have the "
76
+ # "same dimension"
77
+ # )
78
+ # if len(data_shape) != len(overlaps):
79
+ # raise ValueError(
80
+ # f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the "
81
+ # "same dimension"
82
+ # )
83
+ # # overlaps = np.array(tile_size) - np.array(grid_shape)
84
+ # if (np.array(overlaps) < 0).any():
85
+ # raise ValueError(
86
+ # "Tile overlap must be positive or zero in all dimension."
87
+ # )
88
+ # if ((np.array(overlaps) % 2) != 0).any():
89
+ # # TODO: currently not required by CAREamics tiling,
90
+ # # -> because floor divide is used.
91
+ # raise ValueError("Tile overlaps must be even.")
92
+
93
+ # # initialize attributes
94
+ # self.data_shape = data_shape
95
+ # self.overlaps = overlaps
96
+ # self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps))
97
+ # self.patch_shape = tile_size
98
+ # self.trim_boundary = trim_boundary
99
+
100
+ # def compute_tile_info(self, index: int, axes: str):
101
+
102
+ # # TODO: better axis validation, data should already be in the form SC(Z)YX
103
+
104
+ # # validate axes
105
+ # check_axes_validity(axes)
106
+ # # z will be -1 if not present
107
+ # spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")]
108
+
109
+ # # convert to numpy for convenience
110
+ # data_shape = np.array(self.data_shape)
111
+ # patch_shape = np.array(self.patch_shape)
112
+
113
+ # # --- calculate stitch coords
114
+ # stitch_coords_start = np.array(self.get_location_from_dataset_idx(index))
115
+ # stitch_coords_end = stitch_coords_start + np.array(self.grid_shape)
116
+
117
+ # # --- patch coords
118
+ # patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2
119
+ # patch_coords_end = patch_coords_start + patch_shape
120
+
121
+ # # --- replace out of bounds indices
122
+
123
+ # out_of_lower_bound = stitch_coords_start < 0
124
+ # out_of_upper_bound = stitch_coords_end > data_shape
125
+
126
+ # stitch_coords_start[out_of_lower_bound] = 0
127
+ # stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]
128
+
129
+ # # --- calculate overlap crop coords
130
+ # overlap_crop_coords_start = stitch_coords_start - patch_coords_start
131
+ # overlap_crop_coords_end = overlap_crop_coords_start + (
132
+ # stitch_coords_end - stitch_coords_start
133
+ # )
134
+
135
+ # # --- combine start and end
136
+ # stitch_coords = tuple(
137
+ # (stitch_coords_start[axis], stitch_coords_end[axis])
138
+ # for axis in spatial_axes
139
+ # if axis != -1
140
+ # )
141
+ # overlap_crop_coords = tuple(
142
+ # (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis])
143
+ # for axis in spatial_axes
144
+ # if axis != -1
145
+ # )
146
+
147
+ # channel_axis = axes.find("C")
148
+ # array_shape_processed = tuple(
149
+ # data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1
150
+ # )
151
+
152
+ # tile_info = TileInformation(
153
+ # array_shape=array_shape_processed,
154
+ # last_tile=index == self.total_grid_count() - 1,
155
+ # overlap_crop_coords=overlap_crop_coords,
156
+ # stitch_coords=stitch_coords,
157
+ # sample_id=0, # TODO: in iterable dataset this is also always 0 pretty sure
158
+ # )
159
+ # return tile_info
160
+
161
+ # def patch_offset(self):
162
+ # return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
163
+
164
+ # def get_individual_dim_grid_count(self, dim: int):
165
+ # """
166
+ # Returns the number of the grid in the specified dimension, ignoring all other
167
+ # dimensions.
168
+ # """
169
+ # assert dim < len(
170
+ # self.data_shape
171
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
172
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
173
+
174
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
175
+ # return self.data_shape[dim]
176
+ # elif self.trim_boundary is False:
177
+ # return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
178
+ # else:
179
+ # excess_size = self.patch_shape[dim] - self.grid_shape[dim]
180
+ # return int(
181
+ # np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
182
+ # )
183
+
184
+ # def total_grid_count(self):
185
+ # """
186
+ # Returns the total number of grids in the dataset.
187
+ # """
188
+ # return self.grid_count(0) * self.get_individual_dim_grid_count(0)
189
+
190
+ # def grid_count(self, dim: int):
191
+ # """
192
+ # Returns the total number of grids for one value in the specified dimension.
193
+ # """
194
+ # assert dim < len(
195
+ # self.data_shape
196
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
197
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
198
+ # if dim == len(self.data_shape) - 1:
199
+ # return 1
200
+
201
+ # return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
202
+
203
+ # def get_grid_index(self, dim: int, coordinate: int):
204
+ # """
205
+ # Returns the index of the grid in the specified dimension.
206
+ # """
207
+ # assert dim < len(
208
+ # self.data_shape
209
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
210
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
211
+ # assert (
212
+ # coordinate < self.data_shape[dim]
213
+ # ), (
214
+ # f"Coordinate {coordinate} is out of bounds for data "
215
+ # f"shape {self.data_shape}"
216
+ # )
217
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
218
+ # return coordinate
219
+ # elif self.trim_boundary is False:
220
+ # return np.floor(coordinate / self.grid_shape[dim])
221
+ # else:
222
+ # excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
223
+ # # can be <0 if coordinate is in [0,grid_shape[dim]]
224
+ # return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
225
+
226
+ # def dataset_idx_from_grid_idx(self, grid_idx: tuple):
227
+ # """
228
+ # Returns the index of the grid in the dataset.
229
+ # """
230
+ # assert len(grid_idx) == len(
231
+ # self.data_shape
232
+ # ), (
233
+ # f"Dimension indices {grid_idx} must have the same dimension as data "
234
+ # f"shape {self.data_shape}"
235
+ # )
236
+ # index = 0
237
+ # for dim in range(len(grid_idx)):
238
+ # index += grid_idx[dim] * self.grid_count(dim)
239
+ # return index
240
+
241
+ # def get_patch_location_from_dataset_idx(self, dataset_idx: int):
242
+ # """
243
+ # Returns the patch location of the grid in the dataset.
244
+ # """
245
+ # location = self.get_location_from_dataset_idx(dataset_idx)
246
+ # offset = self.patch_offset()
247
+ # return tuple(np.array(location) - np.array(offset))
248
+
249
+ # def get_dataset_idx_from_grid_location(self, location: tuple):
250
+ # assert len(location) == len(
251
+ # self.data_shape
252
+ # ), (
253
+ # f"Location {location} must have the same dimension as data shape "
254
+ # f"{self.data_shape}"
255
+ # )
256
+ # grid_idx = [
257
+ # self.get_grid_index(dim, location[dim]) for dim in range(len(location))
258
+ # ]
259
+ # return self.dataset_idx_from_grid_idx(tuple(grid_idx))
260
+
261
+ # def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
262
+ # """
263
+ # Returns the grid-start coordinate of the grid in the specified dimension.
264
+ # """
265
+ # assert dim < len(
266
+ # self.data_shape
267
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
268
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
269
+ # assert dim_index < self.get_individual_dim_grid_count(
270
+ # dim
271
+ # ), (
272
+ # f"Dimension index {dim_index} is out of bounds for data shape "
273
+ # f"{self.data_shape}"
274
+ # )
275
+
276
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
277
+ # return dim_index
278
+ # elif self.trim_boundary is False:
279
+ # return dim_index * self.grid_shape[dim]
280
+ # else:
281
+ # excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
282
+ # return dim_index * self.grid_shape[dim] + excess_size
283
+
284
+ # def get_location_from_dataset_idx(self, dataset_idx: int):
285
+ # grid_idx = []
286
+ # for dim in range(len(self.data_shape)):
287
+ # grid_idx.append(dataset_idx // self.grid_count(dim))
288
+ # dataset_idx = dataset_idx % self.grid_count(dim)
289
+ # location = [
290
+ # self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
291
+ # for dim in range(len(self.data_shape))
292
+ # ]
293
+ # return tuple(location)
294
+
295
+ # def on_boundary(self, dataset_idx: int, dim: int):
296
+ # """
297
+ # Returns True if the grid is on the boundary in the specified dimension.
298
+ # """
299
+ # assert dim < len(
300
+ # self.data_shape
301
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
302
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
303
+
304
+ # if dim > 0:
305
+ # dataset_idx = dataset_idx % self.grid_count(dim - 1)
306
+
307
+ # dim_index = dataset_idx // self.grid_count(dim)
308
+ # return (
309
+ # dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
310
+ # )
311
+
312
+ # def next_grid_along_dim(self, dataset_idx: int, dim: int):
313
+ # """
314
+ # Returns the index of the grid in the specified dimension in the specified "
315
+ # "direction.
316
+ # """
317
+ # assert dim < len(
318
+ # self.data_shape
319
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
320
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
321
+ # new_idx = dataset_idx + self.grid_count(dim)
322
+ # if new_idx >= self.total_grid_count():
323
+ # return None
324
+ # return new_idx
325
+
326
+ # def prev_grid_along_dim(self, dataset_idx: int, dim: int):
327
+ # """
328
+ # Returns the index of the grid in the specified dimension in the specified "
329
+ # "direction.
330
+ # """
331
+ # assert dim < len(
332
+ # self.data_shape
333
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
334
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
335
+ # new_idx = dataset_idx - self.grid_count(dim)
336
+ # if new_idx < 0:
337
+ # return None
338
+
339
+
340
+ # if __name__ == "__main__":
341
+ # data_shape = (1, 1, 103, 103, 2)
342
+ # grid_shape = (1, 1, 16, 16, 2)
343
+ # patch_shape = (1, 1, 32, 32, 2)
344
+ # overlap = tuple(np.array(patch_shape) - np.array(grid_shape))
345
+
346
+ # trim_boundary = False
347
+ # manager = TilingManager(
348
+ # data_shape=data_shape,
349
+ # tile_size=patch_shape,
350
+ # overlaps=overlap,
351
+ # trim_boundary=trim_boundary,
352
+ # )
353
+ # gc = manager.total_grid_count()
354
+ # print("Grid count", gc)
355
+ # for i in range(gc):
356
+ # loc = manager.get_location_from_dataset_idx(i)
357
+ # print(i, loc)
358
+ # inferred_i = manager.get_dataset_idx_from_grid_location(loc)
359
+ # assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"
360
+
361
+ # for i in range(5):
362
+ # print(manager.on_boundary(40, i))
@@ -76,8 +76,22 @@ def stitch_prediction_single(
76
76
  numpy.ndarray
77
77
  Full image, with dimensions SC(Z)YX.
78
78
  """
79
- # retrieve whole array size
80
- input_shape = (1, *tile_infos[0].array_shape) # add S dim
79
+ # TODO: this is hacky... need a better way to deal with when input channels and
80
+ # target channels do not match
81
+ if len(tile_infos[0].array_shape) == 4:
82
+ # 4 dimensions => 3 spatial dimensions so -4 is channel dimension
83
+ tile_channels = tiles[0].shape[-4]
84
+ elif len(tile_infos[0].array_shape) == 3:
85
+ # 3 dimensions => 2 spatial dimensions so -3 is channel dimension
86
+ tile_channels = tiles[0].shape[-3]
87
+ else:
88
+ # Note pretty sure this is unreachable because array shape is already
89
+ # validated by TileInformation
90
+ raise ValueError(
91
+ f"Unsupported number of output dimension {len(tile_infos[0].array_shape)}"
92
+ )
93
+ # retrieve whole array size, add S dim and use number of channels in tile
94
+ input_shape = (1, tile_channels, *tile_infos[0].array_shape[1:])
81
95
  predicted_image = np.zeros(input_shape, dtype=np.float32)
82
96
 
83
97
  for tile, tile_info in zip(tiles, tile_infos):
@@ -161,7 +161,7 @@ def _get_stratified_coords(
161
161
  coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
162
162
 
163
163
  grid_random_increment = rng.integers(
164
- _odd_jitter_func(float(max(steps)), rng)
164
+ _odd_jitter_func(float(max(steps)), rng) # type: ignore
165
165
  * np.ones_like(coordinate_grid).astype(np.int32)
166
166
  - 1,
167
167
  size=coordinate_grid.shape,