careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of careamics might be problematic.

Files changed (64):
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/config/likelihood_model.py
@@ -0,0 +1,43 @@
+ """Likelihood model."""
+
+ from typing import Literal, Optional, Union
+
+ import torch
+ from pydantic import BaseModel, ConfigDict
+
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+ )
+
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ class GaussianLikelihoodConfig(BaseModel):
+     """Gaussian likelihood configuration."""
+
+     model_config = ConfigDict(validate_assignment=True)
+
+     predict_logvar: Optional[Literal["pixelwise"]] = None
+     """If `pixelwise`, the log-variance is computed for each pixel; otherwise it
+     is not computed."""
+
+     logvar_lowerbound: Union[float, None] = None
+     """The lower bound value for the log-variance."""
+
+
+ class NMLikelihoodConfig(BaseModel):
+     """Noise model likelihood configuration."""
+
+     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+     data_mean: torch.Tensor = torch.zeros(1)
+     """The mean of the data, used to unnormalize data for noise model evaluation.
+     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+     data_std: torch.Tensor = torch.ones(1)
+     """The standard deviation of the data, used to unnormalize data for noise
+     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+     noise_model: Union[NoiseModel, None] = None
+     """The noise model instance used to compute the likelihood."""
careamics/config/nm_model.py
@@ -0,0 +1,101 @@
+ """Noise models config."""
+
+ from pathlib import Path
+ from typing import Literal, Optional, Union
+
+ import numpy as np
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ # TODO: add histogram-based noise model
+
+
+ class GaussianMixtureNMConfig(BaseModel):
+     """Gaussian mixture noise model."""
+
+     model_config = ConfigDict(
+         protected_namespaces=(),
+         validate_assignment=True,
+         arbitrary_types_allowed=True,
+         extra="allow",
+     )
+     # model type
+     model_type: Literal["GaussianMixtureNoiseModel"]
+
+     path: Optional[Union[Path, str]] = None
+     """Path to the directory where the trained noise model (*.npz) is saved in the
+     `train` method."""
+
+     signal: Optional[Union[str, Path, np.ndarray]] = None
+     """Path to the file containing the signal, or the respective numpy array."""
+
+     observation: Optional[Union[str, Path, np.ndarray]] = None
+     """Path to the file containing the observation, or the respective numpy array."""
+
+     weight: Optional[np.ndarray] = None
+     """A [3 * n_gaussian, n_coeff] sized array containing the values of the weights
+     describing the GMM noise model, with each row corresponding to one
+     parameter of each gaussian, namely [mean, standard deviation and weight].
+     Specifically, rows are organized as follows:
+     - the first n_gaussian rows correspond to the means
+     - the next n_gaussian rows correspond to the weights
+     - the last n_gaussian rows correspond to the standard deviations
+     If `weight=None`, the weight array is initialized using the `min_signal`
+     and `max_signal` parameters."""
+
+     n_gaussian: int = Field(default=1, ge=1)
+     """Number of gaussians used for the GMM."""
+
+     n_coeff: int = Field(default=2, ge=2)
+     """Number of coefficients describing the functional relationship between
+     gaussian parameters and the signal: 2 implies a linear relationship, 3 a
+     quadratic relationship, and so on."""
+
+     min_signal: float = Field(default=0.0, ge=0.0)
+     """Minimum signal intensity expected in the image."""
+
+     max_signal: float = Field(default=1.0, ge=0.0)
+     """Maximum signal intensity expected in the image."""
+
+     min_sigma: float = Field(default=200.0, ge=0.0)  # TODO took from nb in pn2v
+     """Minimum standard deviation allowed in the GMM; all standard deviation
+     values below this are clamped to this value."""
+
+     tol: float = Field(default=1e-10)
+     """Tolerance used in the computation of the noise model likelihood."""
+
+     @model_validator(mode="after")
+     def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
+         """Validate paths provided in the config.
+
+         Returns
+         -------
+         Self
+             Returns itself.
+         """
+         if self.path and (self.signal is not None or self.observation is not None):
+             raise ValueError(
+                 "Either only 'path' to a pre-trained noise model should be "
+                 "provided, or only signal and observation in the form of paths "
+                 "or numpy arrays."
+             )
+         if not self.path and (self.signal is None or self.observation is None):
+             raise ValueError(
+                 "Either only 'path' to a pre-trained noise model should be "
+                 "provided, or only signal and observation in the form of paths "
+                 "or numpy arrays."
+             )
+         return self
+
+
+ # The noise model is given by a set of GMMs, one for each target,
+ # e.g., 2 target channels, 2 noise models.
+ class MultiChannelNMConfig(BaseModel):
+     """Noise model config aggregating noise models for single output channels."""
+
+     # TODO: check that this model config is OK
+     model_config = ConfigDict(
+         validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
+     )
+     noise_models: list[GaussianMixtureNMConfig]
+     """List of noise models, one for each target channel."""
careamics/config/support/supported_activations.py
@@ -24,3 +24,4 @@ class SupportedActivation(str, BaseEnum):
      TANH = "Tanh"
      RELU = "ReLU"
      LEAKYRELU = "LeakyReLU"
+     ELU = "ELU"
careamics/config/support/supported_algorithms.py
@@ -6,15 +6,28 @@ from careamics.utils import BaseEnum
  
  
  class SupportedAlgorithm(str, BaseEnum):
-     """Algorithms available in CAREamics.
-
-     # TODO
-     """
+     """Algorithms available in CAREamics."""
  
      N2V = "n2v"
+     """Noise2Void algorithm, a self-supervised approach based on blind denoising."""
+
      CARE = "care"
+     """Content-aware image restoration, a supervised algorithm used for a variety
+     of tasks."""
+
      N2N = "n2n"
+     """Noise2Noise algorithm, a self-supervised denoising scheme based on comparing
+     noisy images of the same sample."""
+
+     MUSPLIT = "musplit"
+     """An image splitting approach based on ladder VAE architectures."""
+
+     DENOISPLIT = "denoisplit"
+     """An image splitting and denoising approach based on ladder VAE architectures."""
+
      CUSTOM = "custom"
+     """Custom algorithm, used for cases where a custom architecture is provided."""
+
      # PN2V = "pn2v"
      # HDN = "hdn"
      # SEG = "segmentation"
careamics/config/support/supported_architectures.py
@@ -4,17 +4,14 @@ from careamics.utils import BaseEnum
  
  
  class SupportedArchitecture(str, BaseEnum):
-     """Supported architectures.
+     """Supported architectures."""
  
-     # TODO add details, in particular where to find the API for the models
-
-     - UNet: classical UNet compatible with N2V2
-     - VAE: variational Autoencoder
-     - Custom: custom model registered with `@register_model` decorator
-     """
-
-     UNET = "UNet"
-     VAE = "VAE"
-     CUSTOM = (
-         "Custom"  # TODO all the others tags are small letters, except the architect
-     )
+     UNET = "UNet"
+     """UNet architecture used with N2V, CARE and Noise2Noise."""
+
+     LVAE = "LVAE"
+     """Ladder Variational Autoencoder used for muSplit and denoiSplit."""
+
+     CUSTOM = "custom"
+     """Keyword used for custom architectures provided by users, only compatible
+     with the `FCNAlgorithmConfig` configuration."""
careamics/config/support/supported_losses.py
@@ -22,6 +22,8 @@ class SupportedLoss(str, BaseEnum):
      N2V = "n2v"
      # PN2V = "pn2v"
      # HDN = "hdn"
+     MUSPLIT = "musplit"
+     DENOISPLIT = "denoisplit"
+     DENOISPLIT_MUSPLIT = "denoisplit_musplit"
      # CE = "ce"
      # DICE = "dice"
-     # CUSTOM = "custom"  # TODO create mechanism for that
careamics/config/transformations/n2v_manipulate_model.py
@@ -33,7 +33,7 @@ class N2VManipulateModel(TransformModel):
  
      name: Literal["N2VManipulate"] = "N2VManipulate"
      roi_size: int = Field(default=11, ge=3, le=21)
-     masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
+     masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
      strategy: Literal["uniform", "median"] = Field(default="uniform")
      struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
      struct_mask_span: int = Field(default=5, ge=3, le=15)
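
A short sketch of the relaxed bound (assuming the other fields keep their defaults, including the unchanged default of 0.2):

    from careamics.config.transformations.n2v_manipulate_model import N2VManipulateModel

    # values above 1.0 are now accepted, up to the new le=10.0
    transform = N2VManipulateModel(masked_pixel_percentage=5.0)

    try:
        N2VManipulateModel(masked_pixel_percentage=20.0)  # > 10.0 still fails
    except ValueError as err:
        print(err)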
careamics/config/vae_algorithm_model.py
@@ -0,0 +1,171 @@
+ """Algorithm configuration."""
+
+ from __future__ import annotations
+
+ from pprint import pformat
+ from typing import Literal, Optional, Union
+
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ from careamics.config.support import SupportedAlgorithm, SupportedLoss
+
+ from .architectures import CustomModel, LVAEModel
+ from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
+ from .nm_model import MultiChannelNMConfig
+ from .optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+ class VAEAlgorithmConfig(BaseModel):
+     """Algorithm configuration.
+
+     This Pydantic model validates the parameters governing the components of the
+     training algorithm: which algorithm, loss function, model architecture,
+     optimizer, and learning rate scheduler to use.
+
+     Currently, we only support the `musplit`, `denoisplit` and `custom` algorithms.
+     The `custom` algorithm allows you to register your own architecture and select
+     it using its name as `name` in the custom pydantic model.
+
+     Attributes
+     ----------
+     algorithm : Literal["musplit", "denoisplit", "custom"]
+         Algorithm to use.
+     loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+         Loss function to use.
+     model : Union[LVAEModel, CustomModel]
+         Model architecture to use.
+     noise_model : Optional[MultiChannelNMConfig]
+         Noise model to use.
+     noise_model_likelihood_model : Optional[NMLikelihoodConfig]
+         Noise model likelihood to use.
+     gaussian_likelihood_model : Optional[GaussianLikelihoodConfig]
+         Gaussian likelihood to use.
+     optimizer : OptimizerModel, optional
+         Optimizer to use.
+     lr_scheduler : LrSchedulerModel, optional
+         Learning rate scheduler to use.
+
+     Raises
+     ------
+     ValueError
+         Algorithm parameter type validation errors.
+     ValueError
+         If the algorithm, loss and model are not compatible.
+
+     Examples
+     --------
+     # TODO add once finalized
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         protected_namespaces=(),  # allows to use model_* as a field name
+         validate_assignment=True,
+         extra="allow",
+     )
+
+     # Mandatory fields
+     # defined in SupportedAlgorithm
+     # TODO: Use supported Enum classes for typing?
+     #   - values can still be passed as strings and they will be cast to Enum
+     algorithm_type: Literal["vae"]
+     algorithm: Literal["musplit", "denoisplit", "custom"]
+     loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+     model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
+
+     # TODO: these are configs, change naming of attrs
+     noise_model: Optional[MultiChannelNMConfig] = None
+     noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
+     gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
+
+     # Optional fields
+     optimizer: OptimizerModel = OptimizerModel()
+     """Optimizer to use, defined in SupportedOptimizer."""
+
+     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+
+     @model_validator(mode="after")
+     def algorithm_cross_validation(self: Self) -> Self:
+         """Validate the algorithm model based on `algorithm`.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         # musplit
+         if self.algorithm == SupportedAlgorithm.MUSPLIT:
+             if self.loss != SupportedLoss.MUSPLIT:
+                 raise ValueError(
+                     f"Algorithm {self.algorithm} only supports loss `musplit`."
+                 )
+
+         # denoisplit
+         if self.algorithm == SupportedAlgorithm.DENOISPLIT:
+             if self.loss not in [
+                 SupportedLoss.DENOISPLIT,
+                 SupportedLoss.DENOISPLIT_MUSPLIT,
+             ]:
+                 raise ValueError(
+                     f"Algorithm {self.algorithm} only supports loss `denoisplit` "
+                     "or `denoisplit_musplit`."
+                 )
+             if (
+                 self.loss == SupportedLoss.DENOISPLIT
+                 and self.model.predict_logvar is not None
+             ):
+                 raise ValueError(
+                     "Algorithm `denoisplit` with loss `denoisplit` only supports "
+                     "`predict_logvar` as `None`."
+                 )
+             if self.noise_model is None:
+                 raise ValueError("Algorithm `denoisplit` requires a noise model.")
+         # TODO: what if algorithm is not musplit or denoisplit (HDN?)
+         return self
+
+     @model_validator(mode="after")
+     def output_channels_validation(self: Self) -> Self:
+         """Validate the consistency between number of out channels and noise models.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         if self.noise_model is not None:
+             assert self.model.output_channels == len(self.noise_model.noise_models), (
+                 f"Number of output channels ({self.model.output_channels}) must match "
+                 f"the number of noise models ({len(self.noise_model.noise_models)})."
+             )
+         return self
+
+     @model_validator(mode="after")
+     def predict_logvar_validation(self: Self) -> Self:
+         """Validate the consistency of `predict_logvar` throughout the model.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         if self.gaussian_likelihood_model is not None:
+             assert (
+                 self.model.predict_logvar
+                 == self.gaussian_likelihood_model.predict_logvar
+             ), (
+                 f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
+                 "the Gaussian likelihood model `predict_logvar` "
+                 f"({self.gaussian_likelihood_model.predict_logvar})."
+             )
+         return self
+
+     def __str__(self) -> str:
+         """Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
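
A construction sketch exercising `algorithm_cross_validation`. Note that `LVAEModel` is defined in `lvae_model.py`, which this diff does not show, so the `architecture="LVAE"` discriminator value and default-constructibility below are assumptions:

    from careamics.config.architectures import LVAEModel
    from careamics.config.vae_algorithm_model import VAEAlgorithmConfig

    model = LVAEModel(architecture="LVAE")  # assumed to accept defaults otherwise

    # musplit pairs only with the musplit loss
    config = VAEAlgorithmConfig(
        algorithm_type="vae", algorithm="musplit", loss="musplit", model=model
    )

    # denoisplit without a noise model is rejected by the cross-validator
    try:
        VAEAlgorithmConfig(
            algorithm_type="vae", algorithm="denoisplit", loss="denoisplit", model=model
        )
    except ValueError as err:
        print(err)  # "Algorithm `denoisplit` requires a noise model."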
careamics/dataset/tiling/lvae_tiled_patching.py
@@ -0,0 +1,282 @@
+ """Functions to reimplement the tiling in the Disentangle repository."""
+
+ import builtins
+ import itertools
+ from typing import Any, Generator, Optional, Union
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from careamics.config.tile_information import TileInformation
+
+
+ def extract_tiles(
+     arr: NDArray,
+     tile_size: NDArray[np.int_],
+     overlaps: NDArray[np.int_],
+     padding_kwargs: Optional[dict[str, Any]] = None,
+ ) -> Generator[tuple[NDArray, TileInformation], None, None]:
+     """Generate tiles from the input array with specified overlap.
+
+     The tiles cover the whole array; the array is additionally padded to ensure
+     that the section of each tile contributing to the final image comes from the
+     center of the tile.
+
+     The method returns a generator that yields tuples of array and tile
+     information, the latter includes whether the tile is the last one, the
+     coordinates of the overlap crop, and the coordinates of the stitched tile.
+
+     The input array should have shape SC(Z)YX, while the returned tiles have
+     shape C(Z)YX, where C can be a singleton.
+
+     Parameters
+     ----------
+     arr : np.ndarray
+         Array of shape (S, C, (Z), Y, X).
+     tile_size : 1D numpy.ndarray of int
+         Tile sizes in each dimension, of length 2 or 3.
+     overlaps : 1D numpy.ndarray of int
+         Overlap values in each dimension, of length 2 or 3.
+     padding_kwargs : dict, optional
+         The arguments of `np.pad` after the first two arguments, `array` and
+         `pad_width`. If not specified the default will be `{"mode": "reflect"}`.
+         See the `numpy.pad` docs:
+         https://numpy.org/doc/stable/reference/generated/numpy.pad.html.
+
+     Yields
+     ------
+     Generator[tuple[np.ndarray, TileInformation], None, None]
+         Tile generator, yields the tile and additional information.
+     """
+     if padding_kwargs is None:
+         padding_kwargs = {"mode": "reflect"}
+
+     # Iterate over num samples (S)
+     for sample_idx in range(arr.shape[0]):
+         sample = arr[sample_idx, ...]
+         data_shape = np.array(sample.shape)
+
+         # add padding to ensure evenly spaced & overlapping tiles
+         spatial_padding = compute_padding(data_shape, tile_size, overlaps)
+         padding = ((0, 0), *spatial_padding)
+         sample = np.pad(sample, padding, **padding_kwargs)
+
+         # the number of tiles in each dimension, should be of length 2 or 3
+         tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)
+
+         # itertools.product is equivalent to nested loops
+         stitch_size = tile_size - overlaps
+         for tile_grid_coords in itertools.product(
+             *[range(n) for n in tile_grid_shape]
+         ):
+             # calculate crop coordinates
+             crop_coords_start = np.array(tile_grid_coords) * stitch_size
+             crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
+                 ...,
+                 *[
+                     slice(coords, coords + extent)
+                     for coords, extent in zip(crop_coords_start, tile_size)
+                 ],
+             )
+             tile = sample[crop_slices]
+
+             tile_info = compute_tile_info(
+                 np.array(tile_grid_coords),
+                 np.array(data_shape),
+                 np.array(tile_size),
+                 np.array(overlaps),
+                 sample_idx,
+             )
+             # TODO: kinda weird this is a generator,
+             #   -> doesn't really save memory? Don't think there are any places
+             #   where the tiles are not extracted all at the same time.
+             #   Although I guess it would make sense for a zarr tile extractor.
+             yield tile, tile_info
+
+
+ def compute_tile_info(
+     tile_grid_coords: NDArray[np.int_],
+     data_shape: NDArray[np.int_],
+     tile_size: NDArray[np.int_],
+     overlaps: NDArray[np.int_],
+     sample_id: int = 0,
+ ) -> TileInformation:
+     """
+     Compute the tile information for a tile with the coordinates `tile_grid_coords`.
+
+     Parameters
+     ----------
+     tile_grid_coords : 1D np.array of int
+         The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
+         tiling the coordinates of the second tile in the first row of tiles would
+         be (0, 1).
+     data_shape : 1D np.array of int
+         The shape of the data, should be (C, (Z), Y, X) where Z is optional.
+     tile_size : 1D np.array of int
+         Tile sizes in each dimension, of length 2 or 3.
+     overlaps : 1D np.array of int
+         Overlap values in each dimension, of length 2 or 3.
+     sample_id : int, default=0
+         An ID to identify which sample a tile belongs to.
+
+     Returns
+     -------
+     TileInformation
+         Information that describes how to crop and stitch a tile to create a full
+         image.
+     """
+     spatial_dims_shape = data_shape[-len(tile_size) :]
+
+     # the extent of the tile which will make up part of the stitched image
+     stitch_size = tile_size - overlaps
+     stitch_coords_start = tile_grid_coords * stitch_size
+     stitch_coords_end = stitch_coords_start + stitch_size
+
+     tile_coords_start = stitch_coords_start - overlaps // 2
+
+     # --- replace out of bounds indices
+     out_of_lower_bound = stitch_coords_start < 0
+     out_of_upper_bound = stitch_coords_end > spatial_dims_shape
+     stitch_coords_start[out_of_lower_bound] = 0
+     stitch_coords_end[out_of_upper_bound] = spatial_dims_shape[out_of_upper_bound]
+
+     # --- calculate overlap crop coords
+     overlap_crop_coords_start = stitch_coords_start - tile_coords_start
+     overlap_crop_coords_end = overlap_crop_coords_start + (
+         stitch_coords_end - stitch_coords_start
+     )
+
+     # --- combine start and end
+     stitch_coords = tuple(
+         (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+     )
+     overlap_crop_coords = tuple(
+         (start, end)
+         for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+     )
+
+     # --- check if last tile
+     tile_grid_shape = np.array(
+         compute_tile_grid_shape(data_shape, tile_size, overlaps)
+     )
+     last_tile = (tile_grid_coords == (tile_grid_shape - 1)).all()
+
+     tile_info = TileInformation(
+         array_shape=data_shape,
+         last_tile=last_tile,
+         overlap_crop_coords=overlap_crop_coords,
+         stitch_coords=stitch_coords,
+         sample_id=sample_id,
+     )
+     return tile_info
+
+
+ def compute_padding(
+     data_shape: NDArray[np.int_],
+     tile_size: NDArray[np.int_],
+     overlaps: NDArray[np.int_],
+ ) -> tuple[tuple[int, int], ...]:
+     """
+     Calculate padding to ensure stitched data comes from the center of a tile.
+
+     Padding is added to an array with shape `data_shape` so that when tiles are
+     stitched together, the data used always comes from the center of a tile, even
+     for tiles at the boundaries of the array.
+
+     Parameters
+     ----------
+     data_shape : 1D numpy.array of int
+         The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+     tile_size : 1D numpy.array of int
+         The tile size in each dimension, ((Z), Y, X).
+     overlaps : 1D numpy.array of int
+         The tile overlap in each dimension, ((Z), Y, X).
+
+     Returns
+     -------
+     tuple of (int, int)
+         A tuple specifying the padding to add in each dimension; each element is a
+         two-element tuple specifying the padding to add before and after the data.
+         This can be used as the `pad_width` argument to `numpy.pad`.
+     """
+     tile_grid_shape = np.array(
+         compute_tile_grid_shape(data_shape, tile_size, overlaps)
+     )
+     covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps
+
+     pad_before = overlaps // 2
+     pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before
+
+     return tuple((before, after) for before, after in zip(pad_before, pad_after))
+
+
+ def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
+     """Calculate the number of tiles in a specific dimension.
+
+     Parameters
+     ----------
+     axis_size : int
+         The length of the data in a specific dimension.
+     tile_size : int
+         The length of the tiles in a specific dimension.
+     overlap : int
+         The tile overlap in a specific dimension.
+
+     Returns
+     -------
+     int
+         The number of tiles that fit in one dimension given the arguments.
+     """
+     return int(np.ceil(axis_size / (tile_size - overlap)))
+
+
+ def total_n_tiles(
+     data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
+ ) -> int:
+     """Calculate the total number of tiles over all dimensions.
+
+     Parameters
+     ----------
+     data_shape : tuple of int
+         The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+     tile_size : tuple of int
+         The tile size in each dimension, ((Z), Y, X).
+     overlaps : tuple of int
+         The tile overlap in each dimension, ((Z), Y, X).
+
+     Returns
+     -------
+     int
+         The total number of tiles over all dimensions.
+     """
+     result = 1
+     # assume spatial dimensions are the last dimensions, so iterate backwards
+     for i in range(-1, -len(tile_size) - 1, -1):
+         result = result * n_tiles_1d(data_shape[i], tile_size[i], overlaps[i])
+
+     return result
+
+
+ def compute_tile_grid_shape(
+     data_shape: NDArray[np.int_],
+     tile_size: NDArray[np.int_],
+     overlaps: NDArray[np.int_],
+ ) -> tuple[int, ...]:
+     """Calculate the number of tiles in each dimension.
+
+     This can be thought of as a grid of tiles.
+
+     Parameters
+     ----------
+     data_shape : 1D numpy.array of int
+         The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
+     tile_size : 1D numpy.array of int
+         The tile size in each dimension, ((Z), Y, X).
+     overlaps : 1D numpy.array of int
+         The tile overlap in each dimension, ((Z), Y, X).
+
+     Returns
+     -------
+     tuple of int
+         The number of tiles in each direction, ((Z), Y, X).
+     """
+     shape = [0 for _ in range(len(tile_size))]
+     # assume spatial dimensions are the last dimensions, so iterate backwards
+     for i in range(-1, -len(tile_size) - 1, -1):
+         shape[i] = n_tiles_1d(data_shape[i], tile_size[i], overlaps[i])
+     return tuple(shape)
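
A quick sketch of the new tiling helpers on a toy array, using only the functions shown above. With `tile_size=32` and `overlaps=8`, each 64-pixel axis yields ceil(64 / (32 - 8)) = 3 tiles:

    import numpy as np

    from careamics.dataset.tiling.lvae_tiled_patching import (
        compute_tile_grid_shape,
        extract_tiles,
    )

    arr = np.zeros((1, 1, 64, 64))  # (S, C, Y, X)
    tile_size = np.array([32, 32])
    overlaps = np.array([8, 8])

    # number of tiles per spatial axis: a 3x3 grid
    print(compute_tile_grid_shape(np.array(arr.shape[1:]), tile_size, overlaps))

    for tile, info in extract_tiles(arr, tile_size, overlaps):
        # each tile is (C, Y, X) == (1, 32, 32); info carries crop/stitch coordinates
        print(tile.shape, info.last_tile)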