careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (141) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,63 @@
1
+ """Module to get write functions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Literal, Protocol
5
+
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.config.support import SupportedData
9
+
10
+ from .tiff import write_tiff
11
+
12
# Output formats accepted when selecting a write function; "custom" means the
# caller supplies its own write function and extension.
SupportedWriteType = Literal[
    "tiff",
    "custom",
]
13
+
14
+
15
+ # This is very strict, arguments have to be called file_path & img
16
+ # Alternative? - doesn't capture *args & **kwargs
17
+ # WriteFunc = Callable[[Path, NDArray], None]
18
class WriteFunc(Protocol):
    """
    Structural type describing a callable that saves an image array to disk.

    Any callable whose signature matches `__call__` below (same parameter
    names `file_path` and `img`, plus arbitrary extra arguments) satisfies
    this protocol.
    """

    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
        """
        Signature a write function must have (ignoring `self`).

        Parameters
        ----------
        file_path : pathlib.Path
            Destination file path.
        img : numpy.ndarray
            Image data to save.
        *args
            Extra positional arguments for the write function.
        **kwargs
            Extra keyword arguments for the write function.
        """
        ...
36
+
37
+
38
# Registry from data type to its write function. A data type with no entry
# here makes `get_write_func` raise NotImplementedError.
WRITE_FUNCS: dict[SupportedData, WriteFunc] = dict(
    [(SupportedData.TIFF, write_tiff)]
)
41
+
42
+
43
def get_write_func(data_type: SupportedWriteType) -> WriteFunc:
    """
    Look up the write function registered for `data_type`.

    Parameters
    ----------
    data_type : {"tiff", "custom"}
        Data type.

    Returns
    -------
    callable
        Write function.

    Raises
    ------
    NotImplementedError
        If `data_type` is a valid `SupportedData` value but has no registered
        write function. Values that `SupportedData` itself rejects fail earlier,
        with whatever error the `SupportedData` constructor raises.
    """
    # Converting to the enum validates the string. The separate name keeps
    # mypy from confusing the enum with the Literal-typed argument.
    data_type_ = SupportedData(data_type)
    try:
        return WRITE_FUNCS[data_type_]
    except KeyError:
        raise NotImplementedError(
            f"No write function for data type '{data_type}'."
        ) from None
@@ -0,0 +1,40 @@
1
+ """Write tiff function."""
2
+
3
+ from fnmatch import fnmatch
4
+ from pathlib import Path
5
+
6
+ import tifffile
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config.support import SupportedData
10
+
11
+
12
def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    """
    Write tiff files.

    Parameters
    ----------
    file_path : pathlib.Path or str
        Path to file. A string is converted to `pathlib.Path`.
    img : numpy.ndarray
        Image data to save.
    *args
        Positional arguments passed to `tifffile.imwrite`.
    **kwargs
        Keyword arguments passed to `tifffile.imwrite`.

    Raises
    ------
    ValueError
        When the file extension of `file_path` does not match the tiff extension
        pattern from `SupportedData.get_extension_pattern` (Unix shell-style,
        e.g. '*.tif*'). The comparison is case-insensitive on every platform,
        so '.TIF' and '.TIFF' are accepted.
    """
    # TODO: add link to tifffile docs for args kwargs?
    # Accept plain strings too; `.suffix` below requires a Path.
    file_path = Path(file_path)
    pattern = SupportedData.get_extension_pattern(SupportedData.TIFF)
    # `fnmatch` only ignores case where the OS does (e.g. Windows). Lowering
    # both sides gives the same result on every platform.
    if not fnmatch(file_path.suffix.lower(), pattern.lower()):
        raise ValueError(
            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
        )
    tifffile.imwrite(file_path, img, *args, **kwargs)
@@ -0,0 +1,17 @@
1
+ """CAREamics PyTorch Lightning modules."""
2
+
3
__all__ = [
    # from .lightning_module
    "CAREamicsModule",
    "create_careamics_module",
    # from .train_data_module
    "TrainDataModule",
    "create_train_datamodule",
    # from .predict_data_module
    "PredictDataModule",
    "create_predict_datamodule",
    # from .callbacks
    "HyperParametersCallback",
    "ProgressBarCallback",
]
13
+
14
+ from .callbacks import HyperParametersCallback, ProgressBarCallback
15
+ from .lightning_module import CAREamicsModule, create_careamics_module
16
+ from .predict_data_module import PredictDataModule, create_predict_datamodule
17
+ from .train_data_module import TrainDataModule, create_train_datamodule
@@ -0,0 +1,11 @@
1
+ """Callbacks module."""
2
+
3
__all__ = [
    "HyperParametersCallback",  # from .hyperparameters_callback
    "ProgressBarCallback",  # from .progress_bar_callback
    "PredictionWriterCallback",  # from .prediction_writer_callback
]
8
+
9
+ from .hyperparameters_callback import HyperParametersCallback
10
+ from .prediction_writer_callback import PredictionWriterCallback
11
+ from .progress_bar_callback import ProgressBarCallback
@@ -0,0 +1,49 @@
1
+ """Callback saving CAREamics configuration as hyperparameters in the model."""
2
+
3
+ from pytorch_lightning import LightningModule, Trainer
4
+ from pytorch_lightning.callbacks import Callback
5
+
6
+ from careamics.config import Configuration
7
+
8
+
9
class HyperParametersCallback(Callback):
    """
    Copy the CAREamics configuration into the model hyperparameters.

    Storing the configuration as a dictionary in the hyperparameters means it is
    saved in the checkpoints. A CAREamist instance can then load it later.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration to store as model hyperparameters.

    Attributes
    ----------
    config : Configuration
        CAREamics configuration to store as model hyperparameters.
    """

    def __init__(self, config: Configuration) -> None:
        """
        Keep a reference to the configuration.

        Parameters
        ----------
        config : Configuration
            CAREamics configuration to store as model hyperparameters.
        """
        self.config = config

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Merge the dumped configuration into `pl_module.hparams` when training starts.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer (not used).
        pl_module : LightningModule
            PyTorch Lightning module whose hyperparameters are updated.
        """
        # `model_dump` turns the configuration into a plain dictionary.
        hyperparameters = self.config.model_dump()
        pl_module.hparams.update(hyperparameters)
@@ -0,0 +1,20 @@
1
+ """A package for the `PredictionWriterCallback` class and utilities."""
2
+
3
__all__ = [
    # from .prediction_writer_callback
    "PredictionWriterCallback",
    # from .write_strategy_factory
    "create_write_strategy",
    # from .write_strategy
    "WriteStrategy",
    "WriteImage",
    "CacheTiles",
    "WriteTilesZarr",
    # from .write_strategy_factory
    "select_write_extension",
    "select_write_func",
]
13
+
14
+ from .prediction_writer_callback import PredictionWriterCallback
15
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
16
+ from .write_strategy_factory import (
17
+ create_write_strategy,
18
+ select_write_extension,
19
+ select_write_func,
20
+ )
@@ -0,0 +1,56 @@
1
+ """Module containing file path utilities for `WriteStrategy` to use."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
7
+
8
+
9
+ # TODO: move to datasets package ?
10
def get_sample_file_path(
    dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int
) -> Path:
    """
    Return the source file path of one sample of a prediction dataset.

    Parameters
    ----------
    dataset : IterableTiledPredDataset or IterablePredDataset
        Dataset whose `data_files` are indexed.
    sample_id : int
        Sample ID, i.e. the index of the file in `dataset.data_files`.

    Returns
    -------
    Path
        The file path at position `sample_id` of `dataset.data_files`.
    """
    source_files = dataset.data_files
    return source_files[sample_id]
29
+
30
+
31
def create_write_file_path(
    dirpath: Path, file_path: Path, write_extension: str
) -> Path:
    """
    Create the file name for the output file.

    Takes the original file path, moves it to the directory `dirpath`, and
    replaces its last extension with `write_extension`. Only the final suffix
    is replaced. Names with several dots therefore keep the rest of their name
    (e.g. 'sample.1.tif' -> 'sample.1.tiff'), so distinct inputs such as
    'sample.1.tif' and 'sample.2.tif' do not overwrite the same output file.

    Parameters
    ----------
    dirpath : pathlib.Path or str
        The output directory to write file to.
    file_path : pathlib.Path or str
        The original file path.
    write_extension : str
        The extension that output files should have, including the leading dot
        (e.g. '.tiff').

    Returns
    -------
    Path
        The output file path.

    Raises
    ------
    ValueError
        If `write_extension` is not a valid suffix (for example, it lacks the
        leading '.'), or if `file_path` has an empty name. The error comes from
        `pathlib.PurePath.with_suffix`.
    """
    # `with_suffix` swaps only the last suffix, unlike rebuilding from `.stem`,
    # which would lose dotted parts of the name.
    file_name = Path(file_path).with_suffix(write_extension).name
    return Path(dirpath) / file_name
@@ -0,0 +1,233 @@
1
+ """Module containing `PredictionWriterCallback` class."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Sequence, Union
7
+
8
+ from pytorch_lightning import LightningModule, Trainer
9
+ from pytorch_lightning.callbacks import BasePredictionWriter
10
+ from torch.utils.data import DataLoader
11
+
12
+ from careamics.dataset import (
13
+ IterablePredDataset,
14
+ IterableTiledPredDataset,
15
+ )
16
+ from careamics.file_io import SupportedWriteType, WriteFunc
17
+ from careamics.utils import get_logger
18
+
19
+ from .write_strategy import WriteStrategy
20
+ from .write_strategy_factory import create_write_strategy
21
+
22
# Module-level logger for the prediction writer callback.
logger = get_logger(__name__)

# Prediction dataset types the callback accepts. Output file names are derived
# from the original input files these datasets read from.
ValidPredDatasets = Union[
    IterablePredDataset,
    IterableTiledPredDataset,
]
25
+
26
+
27
class PredictionWriterCallback(BasePredictionWriter):
    """
    A PyTorch Lightning callback to save predictions.

    Parameters
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : Path or str, default="predictions"
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory.

    Attributes
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : pathlib.Path, default="predictions"
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory.
    writing_predictions : bool
        If writing predictions is turned on or off.
    """

    def __init__(
        self,
        write_strategy: WriteStrategy,
        dirpath: Union[Path, str] = "predictions",
    ):
        """
        A PyTorch Lightning callback to save predictions.

        Parameters
        ----------
        write_strategy : WriteStrategy
            A strategy for writing predictions.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.
        """
        # Write after every batch; see `write_on_batch_end`.
        super().__init__(write_interval="batch")

        # Toggle for CAREamist to switch off saving if desired
        self.writing_predictions: bool = True

        self.write_strategy: WriteStrategy = write_strategy

        # forward declaration
        self.dirpath: Path
        # attribute initialisation
        self._init_dirpath(dirpath)

    @classmethod
    def from_write_func_params(
        cls,
        write_type: SupportedWriteType,
        tiled: bool,
        write_func: Optional[WriteFunc] = None,
        write_extension: Optional[str] = None,
        write_func_kwargs: Optional[dict[str, Any]] = None,
        dirpath: Union[Path, str] = "predictions",
    ) -> PredictionWriterCallback:  # TODO: change type hint to self (find out how)
        """
        Initialize a `PredictionWriterCallback` from write function parameters.

        This will automatically create a `WriteStrategy` to be passed to the
        initialization of `PredictionWriterCallback`.

        Parameters
        ----------
        write_type : {"tiff", "custom"}
            The data type to save as, includes custom.
        tiled : bool
            Whether the prediction will be tiled or not.
        write_func : WriteFunc, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` a function to save the data must be passed.
        write_extension : str, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` an extension to save the data with must be passed.
        write_func_kwargs : dict of {str: any}, optional
            Additional keyword arguments to be passed to the save function.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.

        Returns
        -------
        PredictionWriterCallback
            Callback for writing predictions.
        """
        write_strategy = create_write_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )
        return cls(write_strategy=write_strategy, dirpath=dirpath)

    def _init_dirpath(self, dirpath: Union[Path, str]) -> None:
        """
        Initialize directory path. Should only be called from `__init__`.

        A relative `dirpath` is resolved against the current working directory,
        and a warning is logged showing the resulting absolute path.

        Parameters
        ----------
        dirpath : pathlib.Path or str
            See `__init__` description.
        """
        dirpath = Path(dirpath)
        if not dirpath.is_absolute():
            dirpath = Path.cwd() / dirpath
            # Note the trailing space: the two literals are concatenated.
            logger.warning(
                "Prediction output directory is not absolute, absolute path "
                f"assumed to be '{dirpath}'"
            )
        self.dirpath = dirpath

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Create the prediction output directory when predict begins.

        Called when fit, validate, test, predict, or tune begins.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        stage : str
            Stage of training e.g. 'predict', 'fit', 'validate'.
        """
        super().setup(trainer, pl_module, stage)
        if stage == "predict":
            # make prediction output directory
            logger.info("Making prediction output directory.")
            self.dirpath.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Optional[Sequence[int]],
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Write predictions at the end of a batch.

        How predictions are written is determined by the attribute
        `write_strategy`. Nothing is written when `writing_predictions` is False.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        prediction : Any
            Prediction outputs of `batch`.
        batch_indices : sequence of int, optional
            Batch indices.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.

        Raises
        ------
        TypeError
            If the prediction dataset is neither `IterablePredDataset` nor
            `IterableTiledPredDataset`.
        """
        # if writing prediction is turned off
        if not self.writing_predictions:
            return

        # Several predict dataloaders come as a list; select the current one.
        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dataloader: DataLoader = (
            dataloaders[dataloader_idx]
            if isinstance(dataloaders, list)
            else dataloaders
        )
        dataset: ValidPredDatasets = dataloader.dataset
        if not isinstance(dataset, (IterablePredDataset, IterableTiledPredDataset)):
            # Note: Error will be raised before here from the source type
            # This is for extra redundancy of errors.
            raise TypeError(
                "Prediction dataset has to be `IterableTiledPredDataset` or "
                "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because "
                "filenames are taken from the original file."
            )

        self.write_strategy.write_batch(
            trainer=trainer,
            pl_module=pl_module,
            prediction=prediction,
            batch_indices=batch_indices,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=dataloader_idx,
            dirpath=self.dirpath,
        )