careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/file_io/write/get_func.py
@@ -0,0 +1,63 @@
+"""Module to get write functions."""
+
+from pathlib import Path
+from typing import Literal, Protocol
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import write_tiff
+
+SupportedWriteType = Literal["tiff", "custom"]
+
+
+# This is very strict, arguments have to be called file_path & img
+# Alternative? - doesn't capture *args & **kwargs
+# WriteFunc = Callable[[Path, NDArray], None]
+class WriteFunc(Protocol):
+    """Protocol for type hinting write functions."""
+
+    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        img : numpy.ndarray
+            Image data to save.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
+    SupportedData.TIFF: write_tiff,
+}
+
+
+def get_write_func(data_type: SupportedWriteType) -> WriteFunc:
+    """
+    Get the write function for the data type.
+
+    Parameters
+    ----------
+    data_type : {"tiff", "custom"}
+        Data type.
+
+    Returns
+    -------
+    callable
+        Write function.
+    """
+    # error raised here if not supported
+    data_type_ = SupportedData(data_type)  # new variable for mypy
+    # error if no write func.
+    if data_type_ not in WRITE_FUNCS:
+        raise NotImplementedError(f"No write function for data type '{data_type}'.")
+
+    return WRITE_FUNCS[data_type_]
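
For orientation, a minimal sketch of how this registry could be used downstream (not taken from the diff; the array and file name are arbitrary, and `get_write_func` is assumed to be re-exported by `careamics.file_io.write` — otherwise it can be imported from the `get_func` module shown above):

    from pathlib import Path

    import numpy as np

    from careamics.file_io.write import get_write_func

    # Resolve the write function for the "tiff" data type; per WRITE_FUNCS above,
    # this returns write_tiff.
    write_func = get_write_func("tiff")
    # The WriteFunc protocol requires the arguments to be named file_path and img.
    write_func(file_path=Path("example.tiff"), img=np.zeros((32, 32), dtype=np.float32))
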
careamics/file_io/write/tiff.py
@@ -0,0 +1,40 @@
+"""Write tiff function."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+
+import tifffile
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+
+def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
+    # TODO: add link to tifffile docs for args and kwargs?
+    """
+    Write tiff files.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path
+        Path to file.
+    img : numpy.ndarray
+        Image data to save.
+    *args
+        Positional arguments passed to `tifffile.imwrite`.
+    **kwargs
+        Keyword arguments passed to `tifffile.imwrite`.
+
+    Raises
+    ------
+    ValueError
+        When the file extension of `file_path` does not match the Unix shell-style
+        pattern '*.tif*'.
+    """
+    if not fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
+        raise ValueError(
+            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
+        )
+    tifffile.imwrite(file_path, img, *args, **kwargs)
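
The suffix check above rejects mismatched extensions before anything is written; as an illustrative sketch (file names and array content are hypothetical, and the import path assumes `write_tiff` is re-exported by `careamics.file_io.write`):

    from pathlib import Path

    import numpy as np

    from careamics.file_io.write import write_tiff

    img = np.zeros((8, 8), dtype=np.uint16)
    write_tiff(Path("prediction.tiff"), img)  # ok: '.tiff' matches '*.tif*', written via tifffile.imwrite
    write_tiff(Path("prediction.png"), img)   # raises ValueError: unexpected extension for save type 'tiff'
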
careamics/lightning/__init__.py
@@ -0,0 +1,17 @@
+"""CAREamics PyTorch Lightning modules."""
+
+__all__ = [
+    "CAREamicsModule",
+    "create_careamics_module",
+    "TrainDataModule",
+    "create_train_datamodule",
+    "PredictDataModule",
+    "create_predict_datamodule",
+    "HyperParametersCallback",
+    "ProgressBarCallback",
+]
+
+from .callbacks import HyperParametersCallback, ProgressBarCallback
+from .lightning_module import CAREamicsModule, create_careamics_module
+from .predict_data_module import PredictDataModule, create_predict_datamodule
+from .train_data_module import TrainDataModule, create_train_datamodule
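
These re-exports define the public surface of `careamics.lightning`; an illustrative import (not part of the diff) is simply:

    from careamics.lightning import (
        CAREamicsModule,
        PredictDataModule,
        TrainDataModule,
        create_train_datamodule,
    )
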
careamics/lightning/callbacks/__init__.py
@@ -0,0 +1,11 @@
+"""Callbacks module."""
+
+__all__ = [
+    "HyperParametersCallback",
+    "ProgressBarCallback",
+    "PredictionWriterCallback",
+]
+
+from .hyperparameters_callback import HyperParametersCallback
+from .prediction_writer_callback import PredictionWriterCallback
+from .progress_bar_callback import ProgressBarCallback
careamics/lightning/callbacks/hyperparameters_callback.py
@@ -0,0 +1,49 @@
+"""Callback saving CAREamics configuration as hyperparameters in the model."""
+
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import Callback
+
+from careamics.config import Configuration
+
+
+class HyperParametersCallback(Callback):
+    """
+    Callback saving the CAREamics configuration as hyperparameters in the model.
+
+    This allows saving the configuration as a dictionary in the checkpoints, and
+    loading it subsequently in a CAREamist instance.
+
+    Parameters
+    ----------
+    config : Configuration
+        CAREamics configuration to be saved as hyperparameter in the model.
+
+    Attributes
+    ----------
+    config : Configuration
+        CAREamics configuration to be saved as hyperparameter in the model.
+    """
+
+    def __init__(self, config: Configuration) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        config : Configuration
+            CAREamics configuration to be saved as hyperparameter in the model.
+        """
+        self.config = config
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """
+        Update the hyperparameters of the model with the configuration on train start.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            PyTorch Lightning trainer, unused.
+        pl_module : LightningModule
+            PyTorch Lightning module.
+        """
+        pl_module.hparams.update(self.config.model_dump())
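
A hedged sketch of how this callback might be attached to a trainer; the `Configuration` instance is assumed to be built elsewhere (for example via the factories in `careamics.config.configuration_factory`), and the trainer arguments are illustrative:

    from pytorch_lightning import Trainer

    from careamics.config import Configuration
    from careamics.lightning.callbacks import HyperParametersCallback


    def build_trainer(config: Configuration) -> Trainer:
        # On train start the callback dumps the configuration into pl_module.hparams,
        # so it is stored in the Lightning checkpoints and can be reloaded later.
        return Trainer(max_epochs=10, callbacks=[HyperParametersCallback(config)])
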
careamics/lightning/callbacks/prediction_writer_callback/__init__.py
@@ -0,0 +1,20 @@
+"""A package for the `PredictionWriterCallback` class and utilities."""
+
+__all__ = [
+    "PredictionWriterCallback",
+    "create_write_strategy",
+    "WriteStrategy",
+    "WriteImage",
+    "CacheTiles",
+    "WriteTilesZarr",
+    "select_write_extension",
+    "select_write_func",
+]
+
+from .prediction_writer_callback import PredictionWriterCallback
+from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
+from .write_strategy_factory import (
+    create_write_strategy,
+    select_write_extension,
+    select_write_func,
+)
careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py
@@ -0,0 +1,56 @@
+"""Module containing file path utilities for `WriteStrategy` to use."""
+
+from pathlib import Path
+from typing import Union
+
+from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
+
+
+# TODO: move to datasets package ?
+def get_sample_file_path(
+    dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int
+) -> Path:
+    """
+    Get the file path for a particular sample.
+
+    Parameters
+    ----------
+    dataset : IterableTiledPredDataset or IterablePredDataset
+        Dataset.
+    sample_id : int
+        Sample ID, the index of the file in the dataset `dataset`.
+
+    Returns
+    -------
+    Path
+        The file path corresponding to the sample with the ID `sample_id`.
+    """
+    return dataset.data_files[sample_id]
+
+
+def create_write_file_path(
+    dirpath: Path, file_path: Path, write_extension: str
+) -> Path:
+    """
+    Create the file name for the output file.
+
+    Takes the original file path, changes the directory to `dirpath` and changes
+    the extension to `write_extension`.
+
+    Parameters
+    ----------
+    dirpath : pathlib.Path
+        The output directory to write file to.
+    file_path : pathlib.Path
+        The original file path.
+    write_extension : str
+        The extension that output files should have.
+
+    Returns
+    -------
+    Path
+        The output file path.
+    """
+    file_name = Path(file_path.stem).with_suffix(write_extension)
+    file_path = dirpath / file_name
+    return file_path
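
As a small worked example of the path handling above (paths are hypothetical):

    from pathlib import Path

    from careamics.lightning.callbacks.prediction_writer_callback.file_path_utils import (
        create_write_file_path,
    )

    out_path = create_write_file_path(
        dirpath=Path("predictions"),
        file_path=Path("data/images/img_001.tif"),
        write_extension=".tiff",
    )
    # out_path == Path("predictions/img_001.tiff"): the directory is swapped for
    # `dirpath` and the suffix for `write_extension`, keeping the original stem.
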
careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py
@@ -0,0 +1,233 @@
+"""Module containing `PredictionWriterCallback` class."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Optional, Sequence, Union
+
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import BasePredictionWriter
+from torch.utils.data import DataLoader
+
+from careamics.dataset import (
+    IterablePredDataset,
+    IterableTiledPredDataset,
+)
+from careamics.file_io import SupportedWriteType, WriteFunc
+from careamics.utils import get_logger
+
+from .write_strategy import WriteStrategy
+from .write_strategy_factory import create_write_strategy
+
+logger = get_logger(__name__)
+
+ValidPredDatasets = Union[IterablePredDataset, IterableTiledPredDataset]
+
+
+class PredictionWriterCallback(BasePredictionWriter):
+    """
+    A PyTorch Lightning callback to save predictions.
+
+    Parameters
+    ----------
+    write_strategy : WriteStrategy
+        A strategy for writing predictions.
+    dirpath : Path or str, default="predictions"
+        The path to the directory where prediction outputs will be saved. If
+        `dirpath` is not absolute it is assumed to be relative to current working
+        directory.
+
+    Attributes
+    ----------
+    write_strategy : WriteStrategy
+        A strategy for writing predictions.
+    dirpath : pathlib.Path, default="predictions"
+        The path to the directory where prediction outputs will be saved. If
+        `dirpath` is not absolute it is assumed to be relative to current working
+        directory.
+    writing_predictions : bool
+        If writing predictions is turned on or off.
+    """
+
+    def __init__(
+        self,
+        write_strategy: WriteStrategy,
+        dirpath: Union[Path, str] = "predictions",
+    ):
+        """
+        A PyTorch Lightning callback to save predictions.
+
+        Parameters
+        ----------
+        write_strategy : WriteStrategy
+            A strategy for writing predictions.
+        dirpath : pathlib.Path or str, default="predictions"
+            The path to the directory where prediction outputs will be saved. If
+            `dirpath` is not absolute it is assumed to be relative to current working
+            directory.
+        """
+        super().__init__(write_interval="batch")
+
+        # Toggle for CAREamist to switch off saving if desired
+        self.writing_predictions: bool = True
+
+        self.write_strategy: WriteStrategy = write_strategy
+
+        # forward declaration
+        self.dirpath: Path
+        # attribute initialisation
+        self._init_dirpath(dirpath)
+
+    @classmethod
+    def from_write_func_params(
+        cls,
+        write_type: SupportedWriteType,
+        tiled: bool,
+        write_func: Optional[WriteFunc] = None,
+        write_extension: Optional[str] = None,
+        write_func_kwargs: Optional[dict[str, Any]] = None,
+        dirpath: Union[Path, str] = "predictions",
+    ) -> PredictionWriterCallback:  # TODO: change type hint to self (find out how)
+        """
+        Initialize a `PredictionWriterCallback` from write function parameters.
+
+        This will automatically create a `WriteStrategy` to be passed to the
+        initialization of `PredictionWriterCallback`.
+
+        Parameters
+        ----------
+        write_type : {"tiff", "custom"}
+            The data type to save as, includes custom.
+        tiled : bool
+            Whether the prediction will be tiled or not.
+        write_func : WriteFunc, optional
+            If a known `write_type` is selected this argument is ignored. For a custom
+            `write_type` a function to save the data must be passed. See notes below.
+        write_extension : str, optional
+            If a known `write_type` is selected this argument is ignored. For a custom
+            `write_type` an extension to save the data with must be passed.
+        write_func_kwargs : dict of {{str: any}}, optional
+            Additional keyword arguments to be passed to the save function.
+        dirpath : pathlib.Path or str, default="predictions"
+            The path to the directory where prediction outputs will be saved. If
+            `dirpath` is not absolute it is assumed to be relative to current working
+            directory.
+
+        Returns
+        -------
+        PredictionWriterCallback
+            Callback for writing predictions.
+        """
+        write_strategy = create_write_strategy(
+            write_type=write_type,
+            tiled=tiled,
+            write_func=write_func,
+            write_extension=write_extension,
+            write_func_kwargs=write_func_kwargs,
+        )
+        return cls(write_strategy=write_strategy, dirpath=dirpath)
+
+    def _init_dirpath(self, dirpath):
+        """
+        Initialize directory path. Should only be called from `__init__`.
+
+        Parameters
+        ----------
+        dirpath : pathlib.Path
+            See `__init__` description.
+        """
+        dirpath = Path(dirpath)
+        if not dirpath.is_absolute():
+            dirpath = Path.cwd() / dirpath
+            logger.warning(
+                "Prediction output directory is not absolute, absolute path assumed to "
+                f"be '{dirpath}'"
+            )
+        self.dirpath = dirpath
+
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        """
+        Create the prediction output directory when predict begins.
+
+        Called when fit, validate, test, predict, or tune begins.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            PyTorch Lightning trainer.
+        pl_module : LightningModule
+            PyTorch Lightning module.
+        stage : str
+            Stage of training e.g. 'predict', 'fit', 'validate'.
+        """
+        super().setup(trainer, pl_module, stage)
+        if stage == "predict":
+            # make prediction output directory
+            logger.info("Making prediction output directory.")
+            self.dirpath.mkdir(parents=True, exist_ok=True)
+
+    def write_on_batch_end(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        prediction: Any,  # TODO: change to expected type
+        batch_indices: Optional[Sequence[int]],
+        batch: Any,  # TODO: change to expected type
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """
+        Write predictions at the end of a batch.
+
+        The method of writing is determined by the attribute `write_strategy`.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            PyTorch Lightning trainer.
+        pl_module : LightningModule
+            PyTorch Lightning module.
+        prediction : Any
+            Prediction outputs of `batch`.
+        batch_indices : sequence of Any, optional
+            Batch indices.
+        batch : Any
+            Input batch.
+        batch_idx : int
+            Batch index.
+        dataloader_idx : int
+            Dataloader index.
+        """
+        # if writing prediction is turned off
+        if not self.writing_predictions:
+            return
+
+        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
+        dataloader: DataLoader = (
+            dataloaders[dataloader_idx]
+            if isinstance(dataloaders, list)
+            else dataloaders
+        )
+        dataset: ValidPredDatasets = dataloader.dataset
+        if not (
+            isinstance(dataset, IterablePredDataset)
+            or isinstance(dataset, IterableTiledPredDataset)
+        ):
+            # Note: Error will be raised before here from the source type
+            # This is for extra redundancy of errors.
+            raise TypeError(
+                "Prediction dataset has to be `IterableTiledPredDataset` or "
+                "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because "
+                "filenames are taken from the original file."
+            )
+
+        self.write_strategy.write_batch(
+            trainer=trainer,
+            pl_module=pl_module,
+            prediction=prediction,
+            batch_indices=batch_indices,
+            batch=batch,
+            batch_idx=batch_idx,
+            dataloader_idx=dataloader_idx,
+            dirpath=self.dirpath,
+        )
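
Finally, a hedged usage sketch tying the pieces together; the argument values are illustrative, and the comment in `__init__` above suggests the `CAREamist` class normally sets this up internally rather than users doing so by hand:

    from pytorch_lightning import Trainer

    from careamics.lightning.callbacks import PredictionWriterCallback

    # Build the callback from write-function parameters; internally this calls
    # create_write_strategy and then the regular constructor.
    write_callback = PredictionWriterCallback.from_write_func_params(
        write_type="tiff",
        tiled=True,
        dirpath="predictions",
    )

    trainer = Trainer(callbacks=[write_callback])
    # trainer.predict(...) would then write one output file per input file of an
    # IterablePredDataset / IterableTiledPredDataset into the 'predictions' directory.

    # Writing can be toggled off without removing the callback:
    write_callback.writing_predictions = False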