careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (87) hide show
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,20 @@
1
+ """Type used to represent all transformations users can create."""
2
+
3
+ from typing import Union
4
+
5
+ from pydantic import Discriminator
6
+ from typing_extensions import Annotated
7
+
8
+ from .n2v_manipulate_model import N2VManipulateModel
9
+ from .xy_flip_model import XYFlipModel
10
+ from .xy_random_rotate90_model import XYRandomRotate90Model
11
+
12
# Discriminated union of every transform configuration a user can create.
# Pydantic reads the ``name`` field of the incoming data to decide which of the
# member models to validate against.
TRANSFORMS_UNION = Annotated[
    Union[XYFlipModel, XYRandomRotate90Model, N2VManipulateModel],
    Discriminator("name"),
]
"""Available transforms in CAREamics."""
@@ -0,0 +1,137 @@
1
+ """Algorithm configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pprint import pformat
6
+ from typing import Literal, Optional, Union
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
9
+ from typing_extensions import Self
10
+
11
+ from careamics.config.support import SupportedAlgorithm, SupportedLoss
12
+
13
+ from .architectures import CustomModel, LVAEModel
14
+ from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
15
+ from .nm_model import MultiChannelNMConfig
16
+ from .optimizer_models import LrSchedulerModel, OptimizerModel
17
+
18
+
19
class VAEAlgorithmConfig(BaseModel):
    """Algorithm configuration for the VAE-based algorithms (muSplit, denoiSplit).

    Besides per-field validation, the `model_validator` methods of this class
    enforce cross-field consistency:

    - `musplit` only supports the `musplit` loss;
    - `denoisplit` supports the `denoisplit` and `denoisplit_musplit` losses,
      requires a noise model, and with the `denoisplit` loss requires the model
      `predict_logvar` to be `None`;
    - when a noise model is given, the model must have one output channel per
      noise model;
    - when a Gaussian likelihood model is given, its `predict_logvar` must match
      the one of the model.

    Violations raise a `ValueError`, which pydantic reports to the caller as a
    `ValidationError`.

    Examples
    --------
    # TODO add once finalized
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    # defined in SupportedAlgorithm
    # TODO: Use supported Enum classes for typing?
    #   - values can still be passed as strings and they will be cast to Enum
    algorithm: Literal["musplit", "denoisplit"]
    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
    # `architecture` field of the model data selects LVAEModel vs CustomModel
    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")

    # TODO: these are configs, change naming of attrs
    noise_model: Optional[MultiChannelNMConfig] = None
    noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
    gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the loss, the model `predict_logvar` or the noise model are
            incompatible with `algorithm`.
        """
        # musplit
        if self.algorithm == SupportedAlgorithm.MUSPLIT:
            if self.loss != SupportedLoss.MUSPLIT:
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `musplit`."
                )

        if self.algorithm == SupportedAlgorithm.DENOISPLIT:
            if self.loss not in [
                SupportedLoss.DENOISPLIT,
                SupportedLoss.DENOISPLIT_MUSPLIT,
            ]:
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `denoisplit` "
                    "or `denoisplit_musplit`."
                )
            # NOTE(review): assumes `self.model` exposes `predict_logvar`; a
            # `CustomModel` without that attribute would raise AttributeError.
            if (
                self.loss == SupportedLoss.DENOISPLIT
                and self.model.predict_logvar is not None
            ):
                raise ValueError(
                    "Algorithm `denoisplit` with loss `denoisplit` only supports "
                    "`predict_logvar` as `None`."
                )
            if self.noise_model is None:
                raise ValueError("Algorithm `denoisplit` requires a noise model.")
        # TODO: what if algorithm is not musplit or denoisplit (HDN?)
        return self

    @model_validator(mode="after")
    def output_channels_validation(self: Self) -> Self:
        """Validate the consistency between number of out channels and noise models.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a noise model is set and the model does not have exactly one
            output channel per noise model.
        """
        if self.noise_model is not None:
            n_noise_models = len(self.noise_model.noise_models)
            # Raise instead of `assert`: asserts are stripped under `python -O`.
            if self.model.output_channels != n_noise_models:
                raise ValueError(
                    f"Number of output channels ({self.model.output_channels}) "
                    f"must match the number of noise models ({n_noise_models})."
                )
        return self

    @model_validator(mode="after")
    def predict_logvar_validation(self: Self) -> Self:
        """Validate the consistency of `predict_logvar` throughout the model.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a Gaussian likelihood model is set and its `predict_logvar`
            differs from the one of the model.
        """
        if self.gaussian_likelihood_model is not None:
            model_logvar = self.model.predict_logvar
            likelihood_logvar = self.gaussian_likelihood_model.predict_logvar
            if model_logvar != likelihood_logvar:
                # A single string message (the previous trailing comma turned
                # the assertion message into a tuple).
                raise ValueError(
                    f"Model `predict_logvar` ({model_logvar}) must match "
                    "Gaussian likelihood model `predict_logvar` "
                    f"({likelihood_logvar})."
                )
        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())
@@ -0,0 +1,364 @@
1
+ """Functions to reimplement the tiling in the Disentangle repository."""
2
+
3
+ import builtins
4
+ import itertools
5
+ from typing import Any, Generator, Optional, Union
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ from careamics.config.tile_information import TileInformation
11
+ from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager
12
+
13
+
14
def extract_tiles(
    arr: NDArray,
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    padding_kwargs: Optional[dict[str, Any]] = None,
) -> Generator[tuple[NDArray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array, which is additionally padded so that the
    section of each tile that contributes to the final image comes from the
    center of the tile.

    The method returns a generator that yields tuples of array and tile
    information; the latter includes whether the tile is the last one, the
    coordinates of the overlap crop, and the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape
    C(Z)YX, where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : 1D numpy.ndarray or tuple of int
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : 1D numpy.ndarray or tuple of int
        Overlap values in each spatial dimension, of length 2 or 3.
    padding_kwargs : dict, optional
        The arguments of `np.pad` after the first two arguments, `array` and
        `pad_width`. If not specified the default will be `{"mode": "reflect"}`.
        See `numpy.pad` docs:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

    Yields
    ------
    tuple of (np.ndarray, TileInformation)
        A tile of shape C(Z)YX and the information needed to stitch it back.
    """
    if padding_kwargs is None:
        padding_kwargs = {"mode": "reflect"}

    # Tuples/lists are accepted as documented; the element-wise arithmetic
    # below (and in the helpers) requires arrays.
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)
    # Step between consecutive tile origins in the padded sample.
    stitch_size = tile_size - overlaps

    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx, ...]
        data_shape = np.array(sample.shape)

        # Pad the spatial axes only (the leading C axis gets (0, 0)) so that the
        # tiles are evenly spaced and overlapping.
        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
        padding = ((0, 0), *spatial_padding)
        sample = np.pad(sample, padding, **padding_kwargs)

        # The number of tiles in each dimension, should be of length 2 or 3
        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)

        # itertools.product is equivalent of nested loops over the tile grid
        for tile_grid_indices in itertools.product(
            *[range(n) for n in tile_grid_shape]
        ):
            grid_indices = np.array(tile_grid_indices)

            # calculate crop coordinates (in the padded sample)
            crop_coords_start = grid_indices * stitch_size
            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                ...,
                *[
                    slice(coords, coords + extent)
                    for coords, extent in zip(crop_coords_start, tile_size)
                ],
            )
            tile = sample[crop_slices]

            tile_info = compute_tile_info(
                grid_indices,
                np.array(data_shape),  # fresh copy for every tile, as before
                tile_size,
                overlaps,
                sample_idx,
            )
            # TODO: kinda weird this is a generator -> it doesn't really save
            # memory, there seem to be no places where the tiles are not
            # extracted all at the same time. It would make more sense for a
            # zarr tile extractor.
            yield tile, tile_info
97
+
98
+
99
def compute_tile_info_legacy(
    grid_index_manager: GridIndexManager, index: int
) -> TileInformation:
    """
    Build the `TileInformation` of the tile found at a given dataset index.

    The grid index manager provides the tile location, the stitched extent
    (`grid_shape`) and the offset between a tile and its stitched region
    (`patch_offset()`). The stitched region is clipped to the data bounds, and
    only the trailing spatial dimensions are kept in the returned coordinates.

    Parameters
    ----------
    grid_index_manager : GridIndexManager
        The grid index manager that keeps track of tile locations.
    index : int
        The dataset index.

    Returns
    -------
    TileInformation
        Information that describes how to crop and stitch a tile to create a full image.

    Raises
    ------
    ValueError
        If `grid_index_manager.data_shape` does not have 4 or 5 dimensions.
    """
    data_shape = np.array(grid_index_manager.data_shape)

    # SC(Z)YX -> number of trailing spatial axes
    n_spatial_dims = {4: 2, 5: 3}.get(len(data_shape))
    if n_spatial_dims is None:
        raise ValueError("Data shape must have 4 or 5 dimensions, equating to SC(Z)YX.")

    stitch_start = np.array(grid_index_manager.get_location_from_dataset_idx(index))
    stitch_end = stitch_start + np.array(grid_index_manager.grid_shape)
    # Tile origin, taken from the unclipped stitch start.
    tile_start = stitch_start - grid_index_manager.patch_offset()

    # Clip the stitched region to the data bounds (both masks are computed
    # before either in-place assignment).
    below_lower = stitch_start < 0
    above_upper = stitch_end > data_shape
    stitch_start[below_lower] = 0
    stitch_end[above_upper] = data_shape[above_upper]

    # TODO: TilingMode (ShiftBoundary) is not available in the current version;
    # the boundary shifting it implies is not implemented here.

    # Position of the stitched region inside the tile.
    crop_start = stitch_start - tile_start
    crop_end = crop_start + (stitch_end - stitch_start)

    is_last_tile = index == grid_index_manager.total_grid_count() - 1

    stitch_coords = tuple(zip(stitch_start, stitch_end))
    overlap_crop_coords = tuple(zip(crop_start, crop_end))

    return TileInformation(
        array_shape=data_shape[1:],  # remove S dim
        last_tile=is_last_tile,
        overlap_crop_coords=overlap_crop_coords[-n_spatial_dims:],
        stitch_coords=stitch_coords[-n_spatial_dims:],
        sample_id=0,
    )
176
+
177
+
178
def compute_tile_info(
    tile_grid_indices: NDArray[np.int_],
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    sample_id: int = 0,
) -> TileInformation:
    """
    Compute the tile information for a tile with the coordinates `tile_grid_indices`.

    Parameters
    ----------
    tile_grid_indices : 1D np.array (or sequence) of int
        The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
        tiling the coordinates for the second tile in the first row of tiles would be
        (0, 1).
    data_shape : 1D np.array (or sequence) of int
        The shape of the data, should be (C, (Z), Y, X) where Z is optional.
    tile_size : 1D np.array (or sequence) of int
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : 1D np.array (or sequence) of int
        Overlap values in each dimension, of length 2 or 3.
    sample_id : int, default=0
        An ID to identify which sample a tile belongs to.

    Returns
    -------
    TileInformation
        Information that describes how to crop and stitch a tile to create a full image.
    """
    # Accept plain sequences too: boolean-mask indexing below requires arrays.
    tile_grid_indices = np.asarray(tile_grid_indices)
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    spatial_dims_shape = data_shape[-len(tile_size) :]

    # The extent of the tile which will make up part of the stitched image.
    stitch_size = tile_size - overlaps
    stitch_coords_start = tile_grid_indices * stitch_size
    stitch_coords_end = stitch_coords_start + stitch_size

    # Tile origin in unpadded coordinates (`overlaps // 2` of padding precedes
    # the data).
    tile_coords_start = stitch_coords_start - overlaps // 2

    # --- replace out of bounds indices (the last tiles may extend past the data)
    out_of_lower_bound = stitch_coords_start < 0
    out_of_upper_bound = stitch_coords_end > spatial_dims_shape
    stitch_coords_start[out_of_lower_bound] = 0
    stitch_coords_end[out_of_upper_bound] = spatial_dims_shape[out_of_upper_bound]

    # --- calculate overlap crop coords (position of the stitched part in the tile)
    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
    overlap_crop_coords_end = overlap_crop_coords_start + (
        stitch_coords_end - stitch_coords_start
    )

    # --- combine start and end; NumPy scalars are not subclasses of Python
    # int/bool, so convert them to built-in types.
    stitch_coords = tuple(
        (int(start), int(end))
        for start, end in zip(stitch_coords_start, stitch_coords_end)
    )
    overlap_crop_coords = tuple(
        (int(start), int(end))
        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
    )

    # --- Check if last tile
    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    last_tile = bool((tile_grid_indices == (tile_grid_shape - 1)).all())

    return TileInformation(
        array_shape=data_shape,
        last_tile=last_tile,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=sample_id,
    )
250
+
251
+
252
def compute_padding(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[tuple[int, int], ...]:
    """
    Calculate padding to ensure stitched data comes from the center of a tile.

    Padding is added to an array with shape `data_shape` so that when tiles are
    stitched together, the data used always comes from the center of a tile, even for
    tiles at the boundaries of the array.

    Parameters
    ----------
    data_shape : 1D numpy.array (or sequence) of int
        The shape of the data to be tiled and stitched together, e.g.
        (S, C, (Z), Y, X); only the trailing `len(tile_size)` spatial dimensions
        are used.
    tile_size : 1D numpy.array (or sequence) of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array (or sequence) of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of (int, int)
        A tuple specifying the padding to add in each spatial dimension, each
        element is a two element tuple of Python ints specifying the padding to
        add before and after the data. This can be used as the `pad_width`
        argument to `numpy.pad` (after prepending entries for non-spatial axes).
    """
    # Accept plain sequences too; the arithmetic below needs arrays.
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    # Total extent spanned by the tile grid: one stitch step per tile plus the
    # trailing overlap of the last tile.
    covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps

    pad_before = overlaps // 2
    # Whatever the grid covers beyond the data (and `pad_before`) goes after it.
    pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before

    return tuple(
        (int(before), int(after)) for before, after in zip(pad_before, pad_after)
    )
287
+
288
+
289
def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
    """Calculate the number of tiles in a specific dimension.

    Consecutive tiles are offset by ``tile_size - overlap``, so the number of
    tiles needed to cover the axis is ``ceil(axis_size / (tile_size - overlap))``.

    Parameters
    ----------
    axis_size : int
        The length of the data in a specific dimension.
    tile_size : int
        The length of the tiles in a specific dimension.
    overlap : int
        The tile overlap in a specific dimension.

    Returns
    -------
    int
        The number of tiles that fit in one dimension given the arguments.

    Raises
    ------
    ValueError
        If `overlap` is not smaller than `tile_size`; the tiles would then not
        advance along the axis (previously this divided by zero or returned a
        negative count).
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(
            f"Tile overlap ({overlap}) must be smaller than the tile size "
            f"({tile_size})."
        )
    return int(np.ceil(axis_size / stride))
307
+
308
+
309
def total_n_tiles(
    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
) -> int:
    """Return the total number of tiles, i.e. the product over spatial dimensions.

    Parameters
    ----------
    data_shape : sequence of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only its trailing `len(tile_size)` entries are used.
    tile_size : sequence of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : sequence of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    int
        The total number of tiles over all spatial dimensions.
    """
    total = 1
    # Spatial dimensions are the trailing ones: index them from the end,
    # starting with the last axis.
    for offset in range(1, len(tile_size) + 1):
        total *= n_tiles_1d(data_shape[-offset], tile_size[-offset], overlaps[-offset])
    return total
335
+
336
+
337
def compute_tile_grid_shape(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[int, ...]:
    """Return how many tiles are needed along each spatial dimension.

    The result can be seen as the shape of a grid of tiles.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only its trailing `len(tile_size)` entries are used.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of int
        The number of tiles in each direction, ((Z), Y, X).
    """
    n_spatial = len(tile_size)
    # Walk the trailing (spatial) axes from the last one backwards...
    counts_from_last = [
        n_tiles_1d(data_shape[-offset], tile_size[-offset], overlaps[-offset])
        for offset in range(1, n_spatial + 1)
    ]
    # ...then restore the ((Z), Y, X) order.
    return tuple(reversed(counts_from_last))
@@ -1,4 +1,4 @@
1
- """Funtions to read tiff images."""
1
+ """Functions to read tiff images."""
2
2
 
3
3
  import logging
4
4
  from fnmatch import fnmatch
@@ -1,7 +1,8 @@
1
1
  """CAREamics PyTorch Lightning modules."""
2
2
 
3
3
  __all__ = [
4
- "CAREamicsModule",
4
+ "FCNModule",
5
+ "VAEModule",
5
6
  "create_careamics_module",
6
7
  "TrainDataModule",
7
8
  "create_train_datamodule",
@@ -12,6 +13,6 @@ __all__ = [
12
13
  ]
13
14
 
14
15
  from .callbacks import HyperParametersCallback, ProgressBarCallback
15
- from .lightning_module import CAREamicsModule, create_careamics_module
16
+ from .lightning_module import FCNModule, VAEModule, create_careamics_module
16
17
  from .predict_data_module import PredictDataModule, create_predict_datamodule
17
18
  from .train_data_module import TrainDataModule, create_train_datamodule
@@ -10,7 +10,7 @@ class HyperParametersCallback(Callback):
10
10
  """
11
11
  Callback allowing saving CAREamics configuration as hyperparameters in the model.
12
12
 
13
- This allows saving the configuration as dictionnary in the checkpoints, and
13
+ This allows saving the configuration as dictionary in the checkpoints, and
14
14
  loading it subsequently in a CAREamist instance.
15
15
 
16
16
  Parameters
@@ -1,4 +1,4 @@
1
- """Module containing convienience function to create `WriteStrategy`."""
1
+ """Module containing convenience function to create `WriteStrategy`."""
2
2
 
3
3
  from typing import Any, Optional
4
4