careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
"""Module containing different strategies for writing predictions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Optional, Protocol, Sequence, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
|
|
11
|
+
from careamics.config.tile_information import TileInformation
|
|
12
|
+
from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
|
|
13
|
+
from careamics.file_io import WriteFunc
|
|
14
|
+
from careamics.prediction_utils import stitch_prediction_single
|
|
15
|
+
|
|
16
|
+
from .file_path_utils import create_write_file_path, get_sample_file_path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class WriteStrategy(Protocol):
    """Protocol for write strategy classes."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Optional[Sequence[int]],
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write one batch of predictions; every write strategy must implement this.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer running the prediction loop.
        pl_module : LightningModule
            PyTorch Lightning LightningModule producing the predictions.
        prediction : Any
            Predictions computed on `batch`.
        batch_indices : sequence of int, optional
            Indices identifying the samples in the batch.
        batch : Any
            The input batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader the batch came from.
        dirpath : Path
            Directory in which predictions are saved.
        """
        ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class CacheTiles(WriteStrategy):
    """
    A write strategy that will cache tiles.

    Tiles are cached until a whole image is predicted on. Then the stitched
    prediction is saved.

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    tile_cache : list of numpy.ndarray
        Tiles cached for stitching prediction.
    tile_info_cache : list of TileInformation
        Cached tile information for stitching prediction.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A write strategy that will cache tiles.

        Tiles are cached until a whole image is predicted on. Then the stitched
        prediction is saved.

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

        # where tiles will be cached until a whole image has been predicted
        self.tile_cache: list[NDArray] = []
        self.tile_info_cache: list[TileInformation] = []

    @property
    def last_tiles(self) -> list[bool]:
        """
        List of bool to determine whether each tile in the cache is the last tile.

        Returns
        -------
        list of bool
            Whether each tile in the tile cache is the last tile.
        """
        return [tile_info.last_tile for tile_info in self.tile_info_cache]

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: tuple[NDArray, list[TileInformation]],
        batch_indices: Optional[Sequence[int]],
        batch: tuple[NDArray, list[TileInformation]],
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Cache tiles; save the stitched prediction of every completed image.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : tuple of (numpy.ndarray, list of TileInformation)
            Predictions on `batch`: the predicted tiles (first axis indexes
            tiles) and the matching tile information.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        TypeError
            If the trainer prediction dataset is not `IterableTiledPredDataset`.
        """
        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dataloader: DataLoader = (
            dataloaders[dataloader_idx]
            if isinstance(dataloaders, list)
            else dataloaders
        )
        dataset: IterableTiledPredDataset = dataloader.dataset
        if not isinstance(dataset, IterableTiledPredDataset):
            raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.")

        # cache tiles (batches are split into single samples)
        self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0]))
        self.tile_info_cache.extend(prediction[1])

        # A single batch can contain the last tile of more than one image (e.g.
        # when images are split into only a few tiles). Write every completed
        # image now: handling only the first one would delay the others, and
        # any still cached after the final batch would never be written.
        while self._has_last_tile():
            self._write_next_image(dataset=dataset, dirpath=dirpath)

    def _write_next_image(
        self, dataset: IterableTiledPredDataset, dirpath: Path
    ) -> None:
        """
        Stitch, write and evict the first completed image in the tile cache.

        Parameters
        ----------
        dataset : IterableTiledPredDataset
            Prediction dataset, used to look up the input file of the sample.
        dirpath : Path
            Path to directory to save predictions to.
        """
        # get image tiles and remove them from the cache
        tiles, tile_infos = self._get_image_tiles()
        self._clear_cache()

        # stitch prediction
        prediction_image = stitch_prediction_single(
            tiles=tiles, tile_infos=tile_infos
        )

        # write prediction
        sample_id = tile_infos[0].sample_id  # need this to select correct file name
        input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id)
        file_path = create_write_file_path(
            dirpath=dirpath,
            file_path=input_file_path,
            write_extension=self.write_extension,
        )
        self.write_func(
            file_path=file_path, img=prediction_image[0], **self.write_func_kwargs
        )

    def _has_last_tile(self) -> bool:
        """
        Whether a last tile is contained in the cached tiles.

        Returns
        -------
        bool
            Whether a last tile is contained in the cached tiles.
        """
        return any(self.last_tiles)

    def _clear_cache(self) -> None:
        """Remove the tiles in the cache up to (and including) the first last tile."""
        index = self._last_tile_index()
        self.tile_cache = self.tile_cache[index + 1 :]
        self.tile_info_cache = self.tile_info_cache[index + 1 :]

    def _last_tile_index(self) -> int:
        """
        Find the index of the first last tile in the tile cache.

        Returns
        -------
        int
            Index of last tile.

        Raises
        ------
        ValueError
            If there is no last tile in the tile cache.
        """
        last_tiles = self.last_tiles
        if not any(last_tiles):
            raise ValueError("No last tile in the tile cache.")
        return last_tiles.index(True)

    def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]:
        """
        Get the tiles corresponding to a single image.

        Returns
        -------
        tuple of (list of numpy.ndarray, list of TileInformation)
            Tiles and tile information to stitch together a full image.
        """
        index = self._last_tile_index()
        tiles = self.tile_cache[: index + 1]
        tile_infos = self.tile_info_cache[: index + 1]
        return tiles, tile_infos
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class WriteTilesZarr(WriteStrategy):
    """Strategy to write tiles to Zarr file (not implemented yet)."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,
        batch_indices: Optional[Sequence[int]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write predicted tiles directly to a zarr file.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer running the prediction loop.
        pl_module : LightningModule
            PyTorch Lightning LightningModule producing the predictions.
        prediction : Any
            Predictions computed on `batch`.
        batch_indices : sequence of int, optional
            Indices identifying the samples in the batch.
        batch : Any
            The input batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader the batch came from.
        dirpath : Path
            Directory in which predictions are saved.

        Raises
        ------
        NotImplementedError
            Always; zarr writing is not available yet.
        """
        raise NotImplementedError
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class WriteImage(WriteStrategy):
    """
    A strategy for writing image predictions (i.e. un-tiled predictions).

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A strategy for writing image predictions (i.e. un-tiled predictions).

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: NDArray,
        batch_indices: Optional[Sequence[int]],
        batch: NDArray,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Save full images, one file per sample in the batch.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : NDArray
            Predictions on `batch`; the first axis indexes the samples.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        TypeError
            If trainer prediction dataset is not `IterablePredDataset`.
        """
        dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls
        ds: IterablePredDataset = dl.dataset
        if not isinstance(ds, IterablePredDataset):
            raise TypeError("Prediction dataset is not `IterablePredDataset`.")

        # Dataset index of the first sample in this batch; this formula assumes
        # every preceding batch held `dl.batch_size` samples.
        first_sample_id = batch_idx * dl.batch_size
        # Each sample must be written from its own prediction (previously every
        # file received `prediction[0]`, the first image of the batch).
        for i, prediction_image in enumerate(prediction):
            sample_id = first_sample_id + i
            input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id)
            file_path = create_write_file_path(
                dirpath=dirpath,
                file_path=input_file_path,
                write_extension=self.write_extension,
            )
            self.write_func(
                file_path=file_path, img=prediction_image, **self.write_func_kwargs
            )
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Module containing convienience function to create `WriteStrategy`."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from careamics.config.support import SupportedData
|
|
6
|
+
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
7
|
+
|
|
8
|
+
from .write_strategy import CacheTiles, WriteImage, WriteStrategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: Optional[WriteFunc] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> WriteStrategy:
    """
    Build a write strategy from convenient parameters.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        function used to save the data must be passed. See notes below.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing predictions.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # select CacheTiles or WriteTilesZarr (when implemented)
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # un-tiled: resolve the save function first, then the extension
    return WriteImage(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=kwargs,
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: Optional[WriteFunc],
    write_extension: Optional[str],
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Build a write strategy for tiled predictions.

    Intended to choose between `CacheTiles` (cache tiles until a whole image is
    predicted) and `WriteTilesZarr` (write tiles directly to disk); only
    `CacheTiles` is currently returned.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        function used to save the data must be passed.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    """
    if write_type == "zarr":
        # TODO: return a `WriteTilesZarr` strategy once zarr saving is supported
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    return CacheTiles(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=write_func_kwargs,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def select_write_func(
    write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
) -> WriteFunc:
    """
    Pick the function used to write images.

    For `write_type="custom"` the user-supplied `write_func` is returned;
    for any other type the known write function is looked up.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        function used to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type == SupportedData.CUSTOM:
        if write_func is None:
            raise ValueError(
                "A save function must be provided for custom data types."
                # TODO: link to how save functions should be implemented
            )
        return write_func

    return get_write_func(write_type)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def select_write_extension(
    write_type: SupportedWriteType, write_extension: Optional[str] = None
) -> str:
    """
    Pick the extension added to prediction file paths.

    For `write_type="custom"` the user-supplied `write_extension` is returned;
    for any other type the known extension of that data type is used.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type`, the
        extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given, or
        if `write_type` is not a valid `SupportedData` value.
    """
    # converting to the enum up front gives mypy a concrete `SupportedData`
    data_type: SupportedData = SupportedData(write_type)

    if data_type == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError("A save extension must be provided for custom data types.")
        return write_extension

    # kind of a weird pattern -> reason to move get_extension from SupportedData
    return data_type.get_extension(data_type)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Progressbar callback."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Dict, Union
|
|
5
|
+
|
|
6
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
7
|
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training and validation steps."""

    def init_train_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during training.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Training",
            position=2 * self.process_position,
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )

    def init_validation_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during validation.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # `trainer.validate()` has no training bar; when one exists, place the
        # validation bar one row below it.
        row_offset = 1 if self.train_progress_bar is not None else 0
        return tqdm(
            desc="Validating",
            position=2 * self.process_position + row_offset,
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )

    def init_test_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during testing.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Testing",
            position=2 * self.process_position,
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
        """Return the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of the trainer's progress bar metrics.
        """
        return dict(trainer.progress_bar_metrics)
|