careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,69 +1,37 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Literal, Optional, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
|
+
from numpy.typing import NDArray
|
|
8
9
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torch.utils.data.dataloader import default_collate
|
|
10
10
|
|
|
11
11
|
from careamics.config import InferenceConfig
|
|
12
12
|
from careamics.config.support import SupportedData
|
|
13
|
-
from careamics.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from careamics.dataset.in_memory_dataset import (
|
|
19
|
-
InMemoryPredictionDataset,
|
|
20
|
-
)
|
|
21
|
-
from careamics.dataset.iterable_dataset import (
|
|
22
|
-
IterablePredictionDataset,
|
|
13
|
+
from careamics.dataset import (
|
|
14
|
+
InMemoryPredDataset,
|
|
15
|
+
InMemoryTiledPredDataset,
|
|
16
|
+
IterablePredDataset,
|
|
17
|
+
IterableTiledPredDataset,
|
|
23
18
|
)
|
|
19
|
+
from careamics.dataset.dataset_utils import list_files
|
|
20
|
+
from careamics.dataset.tiling.collate_tiles import collate_tiles
|
|
21
|
+
from careamics.file_io.read import get_read_func
|
|
24
22
|
from careamics.utils import get_logger
|
|
25
23
|
|
|
26
|
-
PredictDatasetType = Union[
|
|
24
|
+
PredictDatasetType = Union[
|
|
25
|
+
InMemoryPredDataset,
|
|
26
|
+
InMemoryTiledPredDataset,
|
|
27
|
+
IterablePredDataset,
|
|
28
|
+
IterableTiledPredDataset,
|
|
29
|
+
]
|
|
27
30
|
|
|
28
31
|
logger = get_logger(__name__)
|
|
29
32
|
|
|
30
33
|
|
|
31
|
-
|
|
32
|
-
"""
|
|
33
|
-
Collate tiles received from CAREamics prediction dataloader.
|
|
34
|
-
|
|
35
|
-
CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
|
|
36
|
-
case of non-tiled data, this function will return the arrays. In case of tiled data,
|
|
37
|
-
it will return the arrays, the last tile flag, the overlap crop coordinates and the
|
|
38
|
-
stitch coordinates.
|
|
39
|
-
|
|
40
|
-
Parameters
|
|
41
|
-
----------
|
|
42
|
-
batch : List[Tuple[np.ndarray, TileInformation], ...]
|
|
43
|
-
Batch of tiles.
|
|
44
|
-
|
|
45
|
-
Returns
|
|
46
|
-
-------
|
|
47
|
-
Any
|
|
48
|
-
Collated batch.
|
|
49
|
-
"""
|
|
50
|
-
first_tile_info: TileInformation = batch[0][1]
|
|
51
|
-
# if not tiled, then return arrays
|
|
52
|
-
if not first_tile_info.tiled:
|
|
53
|
-
arrays, _ = zip(*batch)
|
|
54
|
-
|
|
55
|
-
return default_collate(arrays)
|
|
56
|
-
# else we explicit the last_tile flag and coordinates
|
|
57
|
-
else:
|
|
58
|
-
new_batch = [
|
|
59
|
-
(tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
|
|
60
|
-
for tile, t in batch
|
|
61
|
-
]
|
|
62
|
-
|
|
63
|
-
return default_collate(new_batch)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class CAREamicsPredictData(L.LightningDataModule):
|
|
34
|
+
class PredictDataModule(L.LightningDataModule):
|
|
67
35
|
"""
|
|
68
36
|
CAREamics Lightning prediction data module.
|
|
69
37
|
|
|
@@ -82,9 +50,9 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
82
50
|
----------
|
|
83
51
|
pred_config : InferenceModel
|
|
84
52
|
Pydantic model for CAREamics prediction configuration.
|
|
85
|
-
pred_data :
|
|
53
|
+
pred_data : pathlib.Path or str or numpy.ndarray
|
|
86
54
|
Prediction data, can be a path to a folder, a file or a numpy array.
|
|
87
|
-
read_source_func :
|
|
55
|
+
read_source_func : Callable, optional
|
|
88
56
|
Function to read custom types, by default None.
|
|
89
57
|
extension_filter : str, optional
|
|
90
58
|
Filter to filter file extensions for custom types, by default "".
|
|
@@ -95,7 +63,7 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
95
63
|
def __init__(
|
|
96
64
|
self,
|
|
97
65
|
pred_config: InferenceConfig,
|
|
98
|
-
pred_data: Union[Path, str,
|
|
66
|
+
pred_data: Union[Path, str, NDArray],
|
|
99
67
|
read_source_func: Optional[Callable] = None,
|
|
100
68
|
extension_filter: str = "",
|
|
101
69
|
dataloader_params: Optional[dict] = None,
|
|
@@ -118,9 +86,9 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
118
86
|
----------
|
|
119
87
|
pred_config : InferenceModel
|
|
120
88
|
Pydantic model for CAREamics prediction configuration.
|
|
121
|
-
pred_data :
|
|
89
|
+
pred_data : pathlib.Path or str or numpy.ndarray
|
|
122
90
|
Prediction data, can be a path to a folder, a file or a numpy array.
|
|
123
|
-
read_source_func :
|
|
91
|
+
read_source_func : Callable, optional
|
|
124
92
|
Function to read custom types, by default None.
|
|
125
93
|
extension_filter : str, optional
|
|
126
94
|
Filter to filter file extensions for custom types, by default "".
|
|
@@ -182,6 +150,9 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
182
150
|
self.tile_size = pred_config.tile_size
|
|
183
151
|
self.tile_overlap = pred_config.tile_overlap
|
|
184
152
|
|
|
153
|
+
# check if it is tiled
|
|
154
|
+
self.tiled = self.tile_size is not None and self.tile_overlap is not None
|
|
155
|
+
|
|
185
156
|
# read source function
|
|
186
157
|
if pred_config.data_type == SupportedData.CUSTOM:
|
|
187
158
|
# mypy check
|
|
@@ -212,17 +183,29 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
212
183
|
"""
|
|
213
184
|
# if numpy array
|
|
214
185
|
if self.data_type == SupportedData.ARRAY:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
186
|
+
if self.tiled:
|
|
187
|
+
self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
|
|
188
|
+
prediction_config=self.prediction_config,
|
|
189
|
+
inputs=self.pred_data,
|
|
190
|
+
)
|
|
191
|
+
else:
|
|
192
|
+
self.predict_dataset = InMemoryPredDataset(
|
|
193
|
+
prediction_config=self.prediction_config,
|
|
194
|
+
inputs=self.pred_data,
|
|
195
|
+
)
|
|
220
196
|
else:
|
|
221
|
-
self.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
197
|
+
if self.tiled:
|
|
198
|
+
self.predict_dataset = IterableTiledPredDataset(
|
|
199
|
+
prediction_config=self.prediction_config,
|
|
200
|
+
src_files=self.pred_files,
|
|
201
|
+
read_source_func=self.read_source_func,
|
|
202
|
+
)
|
|
203
|
+
else:
|
|
204
|
+
self.predict_dataset = IterablePredDataset(
|
|
205
|
+
prediction_config=self.prediction_config,
|
|
206
|
+
src_files=self.pred_files,
|
|
207
|
+
read_source_func=self.read_source_func,
|
|
208
|
+
)
|
|
226
209
|
|
|
227
210
|
def predict_dataloader(self) -> DataLoader:
|
|
228
211
|
"""
|
|
@@ -236,35 +219,38 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
236
219
|
return DataLoader(
|
|
237
220
|
self.predict_dataset,
|
|
238
221
|
batch_size=self.batch_size,
|
|
239
|
-
collate_fn=
|
|
222
|
+
collate_fn=collate_tiles if self.tiled else None,
|
|
240
223
|
**self.dataloader_params,
|
|
241
|
-
)
|
|
242
|
-
|
|
224
|
+
)
|
|
243
225
|
|
|
244
|
-
class PredictDataWrapper(CAREamicsPredictData):
|
|
245
|
-
"""
|
|
246
|
-
Wrapper around the CAREamics inference Lightning data module.
|
|
247
226
|
|
|
248
|
-
|
|
227
|
+
def create_predict_datamodule(
|
|
228
|
+
pred_data: Union[str, Path, NDArray],
|
|
229
|
+
data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
|
|
230
|
+
axes: str,
|
|
231
|
+
image_means: list[float],
|
|
232
|
+
image_stds: list[float],
|
|
233
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
234
|
+
tile_overlap: Optional[tuple[int, ...]] = None,
|
|
235
|
+
batch_size: int = 1,
|
|
236
|
+
tta_transforms: bool = True,
|
|
237
|
+
read_source_func: Optional[Callable] = None,
|
|
238
|
+
extension_filter: str = "",
|
|
239
|
+
dataloader_params: Optional[dict] = None,
|
|
240
|
+
) -> PredictDataModule:
|
|
241
|
+
"""Create a CAREamics prediction Lightning datamodule.
|
|
242
|
+
|
|
243
|
+
This function is used to explicitely pass the parameters usually contained in an
|
|
249
244
|
`inference_model` configuration.
|
|
250
245
|
|
|
251
246
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
252
247
|
parameters passed to the datamodule are consistent with the model's requirements
|
|
253
|
-
and are coherent.
|
|
248
|
+
and are coherent. This can be done by creating a `Configuration` object beforehand
|
|
249
|
+
and passing its parameters to the different Lightning modules.
|
|
254
250
|
|
|
255
251
|
The data module can be used with Path, str or numpy arrays. To use array data, set
|
|
256
252
|
`data_type` to `array` and pass a numpy array to `train_data`.
|
|
257
253
|
|
|
258
|
-
The default transformations applied to the images are defined in
|
|
259
|
-
`careamics.config.inference_model`. To use different transformations, pass a list
|
|
260
|
-
of transforms. See examples
|
|
261
|
-
for more details.
|
|
262
|
-
|
|
263
|
-
The `mean` and `std` parameters are only used if Normalization is defined either
|
|
264
|
-
in the default transformations or in the `transforms` parameter. If you pass a
|
|
265
|
-
`Normalization` transform in a list as `transforms`, then the mean and std
|
|
266
|
-
parameters will be overwritten by those passed to this method.
|
|
267
|
-
|
|
268
254
|
By default, CAREamics only supports types defined in
|
|
269
255
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
270
256
|
`data_type` to `custom` and provide a function that returns a numpy array from a
|
|
@@ -275,117 +261,73 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
275
261
|
dataloaders, except for `batch_size`, which is set by the `batch_size`
|
|
276
262
|
parameter.
|
|
277
263
|
|
|
278
|
-
Note that if you are using a UNet model and tiling, the tile size must be
|
|
279
|
-
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
280
|
-
avoids artefacts arising from the broken shift invariance induced by the
|
|
281
|
-
pooling layers of the UNet. If your image has less dimensions, as it may
|
|
282
|
-
happen in the Z dimension, consider padding your image.
|
|
283
|
-
|
|
284
264
|
Parameters
|
|
285
265
|
----------
|
|
286
|
-
pred_data :
|
|
266
|
+
pred_data : str or pathlib.Path or numpy.ndarray
|
|
287
267
|
Prediction data.
|
|
288
|
-
data_type :
|
|
268
|
+
data_type : {"array", "tiff", "custom"}
|
|
289
269
|
Data type, see `SupportedData` for available options.
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
tile_size :
|
|
270
|
+
axes : str
|
|
271
|
+
Axes of the data, choosen among SCZYX.
|
|
272
|
+
image_means : list of float
|
|
273
|
+
Mean values for normalization, only used if Normalization is defined.
|
|
274
|
+
image_stds : list of float
|
|
275
|
+
Std values for normalization, only used if Normalization is defined.
|
|
276
|
+
tile_size : tuple of int, optional
|
|
297
277
|
Tile size, 2D or 3D tile size.
|
|
298
|
-
tile_overlap :
|
|
278
|
+
tile_overlap : tuple of int, optional
|
|
299
279
|
Tile overlap, 2D or 3D tile overlap.
|
|
300
|
-
axes : str
|
|
301
|
-
Axes of the data, choosen amongst SCZYX.
|
|
302
280
|
batch_size : int
|
|
303
281
|
Batch size.
|
|
304
282
|
tta_transforms : bool, optional
|
|
305
283
|
Use test time augmentation, by default True.
|
|
306
|
-
read_source_func :
|
|
284
|
+
read_source_func : Callable, optional
|
|
307
285
|
Function to read the source data, used if `data_type` is `custom`, by
|
|
308
286
|
default None.
|
|
309
287
|
extension_filter : str, optional
|
|
310
288
|
Filter for file extensions, used if `data_type` is `custom`, by default "".
|
|
311
289
|
dataloader_params : dict, optional
|
|
312
290
|
Pytorch dataloader parameters, by default {}.
|
|
313
|
-
"""
|
|
314
291
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
mean: float,
|
|
320
|
-
std: float,
|
|
321
|
-
tile_size: Optional[Tuple[int, ...]] = None,
|
|
322
|
-
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
323
|
-
axes: str = "YX",
|
|
324
|
-
batch_size: int = 1,
|
|
325
|
-
tta_transforms: bool = True,
|
|
326
|
-
read_source_func: Optional[Callable] = None,
|
|
327
|
-
extension_filter: str = "",
|
|
328
|
-
dataloader_params: Optional[dict] = None,
|
|
329
|
-
) -> None:
|
|
330
|
-
"""
|
|
331
|
-
Constructor.
|
|
292
|
+
Returns
|
|
293
|
+
-------
|
|
294
|
+
PredictDataModule
|
|
295
|
+
CAREamics prediction datamodule.
|
|
332
296
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
tile_overlap
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
"mean": mean,
|
|
371
|
-
"std": std,
|
|
372
|
-
"tta": tta_transforms,
|
|
373
|
-
"batch_size": batch_size,
|
|
374
|
-
"transforms": [],
|
|
375
|
-
}
|
|
376
|
-
|
|
377
|
-
# validate configuration
|
|
378
|
-
self.prediction_config = InferenceConfig(**prediction_dict)
|
|
379
|
-
|
|
380
|
-
# sanity check on the dataloader parameters
|
|
381
|
-
if "batch_size" in dataloader_params:
|
|
382
|
-
# remove it
|
|
383
|
-
del dataloader_params["batch_size"]
|
|
384
|
-
|
|
385
|
-
super().__init__(
|
|
386
|
-
pred_config=self.prediction_config,
|
|
387
|
-
pred_data=pred_data,
|
|
388
|
-
read_source_func=read_source_func,
|
|
389
|
-
extension_filter=extension_filter,
|
|
390
|
-
dataloader_params=dataloader_params,
|
|
391
|
-
)
|
|
297
|
+
Notes
|
|
298
|
+
-----
|
|
299
|
+
If you are using a UNet model and tiling, the tile size must be
|
|
300
|
+
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
301
|
+
avoids artefacts arising from the broken shift invariance induced by the
|
|
302
|
+
pooling layers of the UNet. If your image has less dimensions, as it may
|
|
303
|
+
happen in the Z dimension, consider padding your image.
|
|
304
|
+
"""
|
|
305
|
+
if dataloader_params is None:
|
|
306
|
+
dataloader_params = {}
|
|
307
|
+
|
|
308
|
+
prediction_dict: dict[str, Any] = {
|
|
309
|
+
"data_type": data_type,
|
|
310
|
+
"tile_size": tile_size,
|
|
311
|
+
"tile_overlap": tile_overlap,
|
|
312
|
+
"axes": axes,
|
|
313
|
+
"image_means": image_means,
|
|
314
|
+
"image_stds": image_stds,
|
|
315
|
+
"tta_transforms": tta_transforms,
|
|
316
|
+
"batch_size": batch_size,
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
# validate configuration
|
|
320
|
+
prediction_config = InferenceConfig(**prediction_dict)
|
|
321
|
+
|
|
322
|
+
# sanity check on the dataloader parameters
|
|
323
|
+
if "batch_size" in dataloader_params:
|
|
324
|
+
# remove it
|
|
325
|
+
del dataloader_params["batch_size"]
|
|
326
|
+
|
|
327
|
+
return PredictDataModule(
|
|
328
|
+
pred_config=prediction_config,
|
|
329
|
+
pred_data=pred_data,
|
|
330
|
+
read_source_func=read_source_func,
|
|
331
|
+
extension_filter=extension_filter,
|
|
332
|
+
dataloader_params=dataloader_params,
|
|
333
|
+
)
|