careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,398 @@
1
+ """Module containing different strategies for writing predictions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Optional, Protocol, Sequence, Union
5
+
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+ from pytorch_lightning import LightningModule, Trainer
9
+ from torch.utils.data import DataLoader
10
+
11
+ from careamics.config.tile_information import TileInformation
12
+ from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
13
+ from careamics.file_io import WriteFunc
14
+ from careamics.prediction_utils import stitch_prediction_single
15
+
16
+ from .file_path_utils import create_write_file_path, get_sample_file_path
17
+
18
+
19
class WriteStrategy(Protocol):
    """Interface that every prediction write strategy has to satisfy."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Optional[Sequence[int]],
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Handle the predictions of a single batch.

        Concrete strategies decide whether the batch is written immediately
        or buffered (for instance tiles that are kept until a full image can
        be stitched).

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer running the prediction loop.
        pl_module : LightningModule
            PyTorch Lightning module that produced the predictions.
        prediction : Any
            Model output for `batch`.
        batch_indices : sequence of int, optional
            Indices of the samples contained in the batch.
        batch : Any
            The input batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader the batch was drawn from.
        dirpath : Path
            Directory in which predictions are saved.
        """
        ...
55
+
56
+
57
class CacheTiles(WriteStrategy):
    """
    A write strategy that will cache tiles.

    Tiles are cached until a whole image is predicted on. Then the stitched
    prediction is saved.

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    tile_cache : list of numpy.ndarray
        Tiles cached for stitching prediction.
    tile_info_cache : list of TileInformation
        Cached tile information for stitching prediction.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A write strategy that will cache tiles.

        Tiles are cached until a whole image is predicted on. Then the stitched
        prediction is saved.

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

        # where tiles will be cached until a whole image has been predicted
        self.tile_cache: list[NDArray] = []
        self.tile_info_cache: list[TileInformation] = []

    @property
    def last_tiles(self) -> list[bool]:
        """
        List of bool to determine whether each tile in the cache is the last tile.

        Returns
        -------
        list of bool
            Whether each tile in the tile cache is the last tile.
        """
        return [tile_info.last_tile for tile_info in self.tile_info_cache]

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: tuple[NDArray, list[TileInformation]],
        batch_indices: Optional[Sequence[int]],
        batch: tuple[NDArray, list[TileInformation]],
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Cache tiles and save every image whose tiles have all been predicted.

        A single batch may contain the last tile of more than one image, so
        completed images are stitched and written until no last tile remains
        in the cache.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : tuple of (numpy.ndarray, list of TileInformation)
            Predicted tiles stacked along the first axis, and the tile
            information of each tile.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : tuple of (numpy.ndarray, list of TileInformation)
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        TypeError
            If the trainer prediction dataset is not `IterableTiledPredDataset`.
        """
        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dataloader: DataLoader = (
            dataloaders[dataloader_idx]
            if isinstance(dataloaders, list)
            else dataloaders
        )
        dataset: IterableTiledPredDataset = dataloader.dataset
        if not isinstance(dataset, IterableTiledPredDataset):
            raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.")

        # cache tiles (batches are split into single samples)
        self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0]))
        self.tile_info_cache.extend(prediction[1])

        # Loop rather than `if`: with a single check only one image is written
        # per batch, and completed images still cached after the final batch
        # would never be saved.
        while self._has_last_tile():

            # get image tiles and remove them from the cache
            tiles, tile_infos = self._get_image_tiles()
            self._clear_cache()

            # stitch prediction
            prediction_image = stitch_prediction_single(
                tiles=tiles, tile_infos=tile_infos
            )

            # write prediction
            sample_id = tile_infos[0].sample_id  # need this to select correct file name
            input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id)
            file_path = create_write_file_path(
                dirpath=dirpath,
                file_path=input_file_path,
                write_extension=self.write_extension,
            )
            # `[0]` drops the leading axis of the stitched array (presumably a
            # singleton sample axis -- confirm against `stitch_prediction_single`)
            self.write_func(
                file_path=file_path, img=prediction_image[0], **self.write_func_kwargs
            )

    def _has_last_tile(self) -> bool:
        """
        Whether a last tile is contained in the cached tiles.

        Returns
        -------
        bool
            Whether a last tile is contained in the cached tiles.
        """
        return any(self.last_tiles)

    def _clear_cache(self) -> None:
        """Remove the tiles in the cache up to and including the first last tile."""
        index = self._last_tile_index()
        self.tile_cache = self.tile_cache[index + 1 :]
        self.tile_info_cache = self.tile_info_cache[index + 1 :]

    def _last_tile_index(self) -> int:
        """
        Find the index of the first last tile in the tile cache.

        Returns
        -------
        int
            Index of the first cached tile flagged as a last tile.

        Raises
        ------
        ValueError
            If there is no last tile in the tile cache.
        """
        last_tiles = self.last_tiles
        if not any(last_tiles):
            raise ValueError("No last tile in the tile cache.")
        # plain Python int (np.where returned a numpy scalar)
        return next(i for i, is_last in enumerate(last_tiles) if is_last)

    def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]:
        """
        Get the tiles corresponding to a single image.

        The tiles returned are those up to and including the first last tile
        in the cache; the cache itself is not modified.

        Returns
        -------
        tuple of (list of numpy.ndarray, list of TileInformation)
            Tiles and tile information to stitch together a full image.
        """
        index = self._last_tile_index()
        tiles = self.tile_cache[: index + 1]
        tile_infos = self.tile_info_cache[: index + 1]
        return tiles, tile_infos
251
+
252
+
253
class WriteTilesZarr(WriteStrategy):
    """Strategy writing predicted tiles straight to a Zarr file (not implemented)."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,
        batch_indices: Optional[Sequence[int]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write the tiles of a batch to a Zarr file.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer running the prediction loop.
        pl_module : LightningModule
            PyTorch Lightning module that produced the predictions.
        prediction : Any
            Model output for `batch`.
        batch_indices : sequence of int, optional
            Indices of the samples contained in the batch.
        batch : Any
            The input batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader the batch was drawn from.
        dirpath : Path
            Directory in which predictions are saved.

        Raises
        ------
        NotImplementedError
            Always; writing tiles to Zarr is not supported yet.
        """
        raise NotImplementedError
294
+
295
+
296
class WriteImage(WriteStrategy):
    """
    A strategy for writing image predictions (i.e. un-tiled predictions).

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A strategy for writing image predictions (i.e. un-tiled predictions).

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: NDArray,
        batch_indices: Optional[Sequence[int]],
        batch: NDArray,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Save full images, one file per sample in the batch.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : numpy.ndarray
            Predictions on `batch`, with samples along the first axis.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : numpy.ndarray
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        TypeError
            If trainer prediction dataset is not `IterablePredDataset`.
        """
        dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls
        ds: IterablePredDataset = dl.dataset
        if not isinstance(ds, IterablePredDataset):
            raise TypeError("Prediction dataset is not `IterablePredDataset`.")

        batch_size = dl.batch_size
        # Write the i-th prediction for the i-th sample (previously the first
        # prediction of the batch was written for every sample).
        for i, prediction_image in enumerate(prediction):
            # global index of the sample, used to look up its input file name
            sample_id = batch_idx * batch_size + i
            input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id)
            file_path = create_write_file_path(
                dirpath=dirpath,
                file_path=input_file_path,
                write_extension=self.write_extension,
            )
            self.write_func(
                file_path=file_path, img=prediction_image, **self.write_func_kwargs
            )
@@ -0,0 +1,215 @@
1
+ """Module containing convenience functions to create `WriteStrategy`."""
2
+
3
+ from typing import Any, Optional
4
+
5
+ from careamics.config.support import SupportedData
6
+ from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
7
+
8
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
9
+
10
+
11
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: Optional[WriteFunc] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> WriteStrategy:
    """
    Build a write strategy from a few convenient parameters.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a
        function to save the data must be passed. See notes below.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type` an
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        `CacheTiles` for tiled predictions, `WriteImage` otherwise.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # select CacheTiles or WriteTilesZarr (when implemented)
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # keyword arguments are evaluated left to right: function, then extension
    return WriteImage(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=kwargs,
    )
77
+
78
+
79
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: Optional[WriteFunc],
    write_extension: Optional[str],
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Build the write strategy used for tiled predictions.

    Currently this is always `CacheTiles`, which keeps tiles until a whole
    image is predicted. `WriteTilesZarr` (writing tiles directly to disk) is
    not available yet.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a
        function to save the data must be passed.
    write_extension : str, optional
        Ignored for a known `write_type`. For a custom `write_type` an
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    """
    # TODO: return a `WriteTilesZarr` strategy here once saving tiles to zarr
    # is supported.
    if write_type == "zarr":
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    chosen_func = select_write_func(write_type=write_type, write_func=write_func)
    chosen_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CacheTiles(
        write_func=chosen_func,
        write_extension=chosen_extension,
        write_func_kwargs=write_func_kwargs,
    )
130
+
131
+
132
def select_write_func(
    write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
) -> WriteFunc:
    """
    Pick the function used to write images.

    Known write types resolve to their built-in writer. For `"custom"` the
    caller-supplied `write_func` is returned instead.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for a known `write_type`. For a custom `write_type` a
        function to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type != SupportedData.CUSTOM:
        return get_write_func(write_type)

    if write_func is None:
        raise ValueError(
            "A save function must be provided for custom data types."
            # TODO: link to how save functions should be implemented
        )
    return write_func
177
+
178
+
179
def select_write_extension(
    write_type: SupportedWriteType, write_extension: Optional[str] = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension` is returned; otherwise
    the extension associated with the known `write_type` is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.
    """
    write_type_: SupportedData = SupportedData(write_type)  # new variable for mypy
    if write_type_ != SupportedData.CUSTOM:
        # kind of a weird pattern -> reason to move get_extension from SupportedData
        return write_type_.get_extension(write_type_)

    if write_extension is None:
        raise ValueError("A save extension must be provided for custom data types.")
    return write_extension
@@ -0,0 +1,90 @@
1
+ """Progressbar callback."""
2
+
3
+ import sys
4
+ from typing import Dict, Union
5
+
6
+ from pytorch_lightning import LightningModule, Trainer
7
+ from pytorch_lightning.callbacks import TQDMProgressBar
8
+ from tqdm import tqdm
9
+
10
+
11
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training and validation steps."""

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Training",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # The main progress bar doesn't exist in `trainer.validate()`. Decide
        # from the trainer state (as Lightning's own TQDMProgressBar does)
        # instead of probing `self.train_progress_bar`: in recent Lightning
        # versions that property raises when the bar was never created rather
        # than returning None.
        has_main_bar = self.trainer.state.fn != "validate"
        return tqdm(
            desc="Validating",
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
        """Override this to customize the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of the trainer's progress bar metrics.
        """
        return {**trainer.progress_bar_metrics}