careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +55 -61
- careamics/cli/conf.py +24 -9
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +53 -18
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +15 -11
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +892 -78
- careamics/config/data/data_model.py +7 -14
- careamics/config/data/ng_data_model.py +8 -15
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -11
- careamics/config/likelihood_model.py +4 -4
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +30 -7
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +8 -38
- careamics/config/transformations/normalize_model.py +3 -4
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +166 -68
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +16 -9
- careamics/lightning/train_data_module.py +29 -18
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +94 -9
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +12 -8
- careamics/models/layers.py +5 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +4 -6
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/dataset.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Generic, Literal, NamedTuple,
|
|
4
|
+
from typing import Any, Generic, Literal, NamedTuple, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from numpy.typing import NDArray
|
|
@@ -51,7 +51,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
51
51
|
data_config: NGDataConfig,
|
|
52
52
|
mode: Mode,
|
|
53
53
|
input_extractor: PatchExtractor[GenericImageStack],
|
|
54
|
-
target_extractor:
|
|
54
|
+
target_extractor: PatchExtractor[GenericImageStack] | None = None,
|
|
55
55
|
):
|
|
56
56
|
self.config = data_config
|
|
57
57
|
self.mode = mode
|
|
@@ -115,7 +115,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
115
115
|
|
|
116
116
|
return patching_strategy
|
|
117
117
|
|
|
118
|
-
def _initialize_transforms(self) ->
|
|
118
|
+
def _initialize_transforms(self) -> Compose | None:
|
|
119
119
|
normalize = NormalizeModel(
|
|
120
120
|
image_means=self.input_stats.means,
|
|
121
121
|
image_stds=self.input_stats.stds,
|
careamics/dataset_ng/factory.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
5
5
|
|
|
6
6
|
from numpy.typing import NDArray
|
|
7
7
|
from typing_extensions import ParamSpec
|
|
@@ -48,8 +48,8 @@ class DatasetType(Enum):
|
|
|
48
48
|
def determine_dataset_type(
|
|
49
49
|
data_type: SupportedData,
|
|
50
50
|
in_memory: bool,
|
|
51
|
-
read_func:
|
|
52
|
-
image_stack_loader:
|
|
51
|
+
read_func: ReadFunc | None = None,
|
|
52
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
53
53
|
) -> DatasetType:
|
|
54
54
|
"""Determine what the dataset type should be based on the input arguments.
|
|
55
55
|
|
|
@@ -121,10 +121,10 @@ def create_dataset(
|
|
|
121
121
|
inputs: Any,
|
|
122
122
|
targets: Any,
|
|
123
123
|
in_memory: bool,
|
|
124
|
-
read_func:
|
|
125
|
-
read_kwargs:
|
|
126
|
-
image_stack_loader:
|
|
127
|
-
image_stack_loader_kwargs:
|
|
124
|
+
read_func: ReadFunc | None = None,
|
|
125
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
126
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
127
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
128
128
|
) -> CareamicsDataset[ImageStack]:
|
|
129
129
|
"""
|
|
130
130
|
Convenience function to create the CAREamicsDataset.
|
|
@@ -201,7 +201,7 @@ def create_array_dataset(
|
|
|
201
201
|
config: NGDataConfig,
|
|
202
202
|
mode: Mode,
|
|
203
203
|
inputs: Sequence[NDArray[Any]],
|
|
204
|
-
targets:
|
|
204
|
+
targets: Sequence[NDArray[Any]] | None,
|
|
205
205
|
) -> CareamicsDataset[InMemoryImageStack]:
|
|
206
206
|
"""
|
|
207
207
|
Create a CAREamicsDataset from array data.
|
|
@@ -223,7 +223,7 @@ def create_array_dataset(
|
|
|
223
223
|
A CAREamicsDataset.
|
|
224
224
|
"""
|
|
225
225
|
input_extractor = create_array_extractor(source=inputs, axes=config.axes)
|
|
226
|
-
target_extractor:
|
|
226
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
227
227
|
if targets is not None:
|
|
228
228
|
target_extractor = create_array_extractor(source=targets, axes=config.axes)
|
|
229
229
|
else:
|
|
@@ -235,7 +235,7 @@ def create_tiff_dataset(
|
|
|
235
235
|
config: NGDataConfig,
|
|
236
236
|
mode: Mode,
|
|
237
237
|
inputs: Sequence[Path],
|
|
238
|
-
targets:
|
|
238
|
+
targets: Sequence[Path] | None,
|
|
239
239
|
) -> CareamicsDataset[InMemoryImageStack]:
|
|
240
240
|
"""
|
|
241
241
|
Create a CAREamicsDataset from tiff files that will be all loaded into memory.
|
|
@@ -260,7 +260,7 @@ def create_tiff_dataset(
|
|
|
260
260
|
source=inputs,
|
|
261
261
|
axes=config.axes,
|
|
262
262
|
)
|
|
263
|
-
target_extractor:
|
|
263
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
264
264
|
if targets is not None:
|
|
265
265
|
target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
|
|
266
266
|
else:
|
|
@@ -273,7 +273,7 @@ def create_czi_dataset(
|
|
|
273
273
|
config: NGDataConfig,
|
|
274
274
|
mode: Mode,
|
|
275
275
|
inputs: Sequence[Path],
|
|
276
|
-
targets:
|
|
276
|
+
targets: Sequence[Path] | None,
|
|
277
277
|
) -> CareamicsDataset[CziImageStack]:
|
|
278
278
|
"""
|
|
279
279
|
Create a dataset from CZI files.
|
|
@@ -296,7 +296,7 @@ def create_czi_dataset(
|
|
|
296
296
|
"""
|
|
297
297
|
|
|
298
298
|
input_extractor = create_czi_extractor(source=inputs, axes=config.axes)
|
|
299
|
-
target_extractor:
|
|
299
|
+
target_extractor: PatchExtractor[CziImageStack] | None
|
|
300
300
|
if targets is not None:
|
|
301
301
|
target_extractor = create_czi_extractor(source=targets, axes=config.axes)
|
|
302
302
|
else:
|
|
@@ -309,7 +309,7 @@ def create_ome_zarr_dataset(
|
|
|
309
309
|
config: NGDataConfig,
|
|
310
310
|
mode: Mode,
|
|
311
311
|
inputs: Sequence[Path],
|
|
312
|
-
targets:
|
|
312
|
+
targets: Sequence[Path] | None,
|
|
313
313
|
) -> CareamicsDataset[ZarrImageStack]:
|
|
314
314
|
"""
|
|
315
315
|
Create a dataset from OME ZARR files.
|
|
@@ -332,7 +332,7 @@ def create_ome_zarr_dataset(
|
|
|
332
332
|
"""
|
|
333
333
|
|
|
334
334
|
input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
|
|
335
|
-
target_extractor:
|
|
335
|
+
target_extractor: PatchExtractor[ZarrImageStack] | None
|
|
336
336
|
if targets is not None:
|
|
337
337
|
target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
|
|
338
338
|
else:
|
|
@@ -345,7 +345,7 @@ def create_custom_file_dataset(
|
|
|
345
345
|
config: NGDataConfig,
|
|
346
346
|
mode: Mode,
|
|
347
347
|
inputs: Sequence[Path],
|
|
348
|
-
targets:
|
|
348
|
+
targets: Sequence[Path] | None,
|
|
349
349
|
*,
|
|
350
350
|
read_func: ReadFunc,
|
|
351
351
|
read_kwargs: dict[str, Any],
|
|
@@ -378,7 +378,7 @@ def create_custom_file_dataset(
|
|
|
378
378
|
input_extractor = create_custom_file_extractor(
|
|
379
379
|
source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
|
|
380
380
|
)
|
|
381
|
-
target_extractor:
|
|
381
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
382
382
|
if targets is not None:
|
|
383
383
|
target_extractor = create_custom_file_extractor(
|
|
384
384
|
source=targets,
|
|
@@ -396,7 +396,7 @@ def create_custom_image_stack_dataset(
|
|
|
396
396
|
config: NGDataConfig,
|
|
397
397
|
mode: Mode,
|
|
398
398
|
inputs: Any,
|
|
399
|
-
targets:
|
|
399
|
+
targets: Any | None,
|
|
400
400
|
image_stack_loader: ImageStackLoader[P, GenericImageStack],
|
|
401
401
|
*args: P.args,
|
|
402
402
|
**kwargs: P.kwargs,
|
|
@@ -436,7 +436,7 @@ def create_custom_image_stack_dataset(
|
|
|
436
436
|
*args,
|
|
437
437
|
**kwargs,
|
|
438
438
|
)
|
|
439
|
-
target_extractor:
|
|
439
|
+
target_extractor: PatchExtractor[GenericImageStack] | None
|
|
440
440
|
if targets is not None:
|
|
441
441
|
target_extractor = create_custom_image_stack_extractor(
|
|
442
442
|
targets,
|
|
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import zarr
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
|
-
from zarr.storage import
|
|
10
|
+
from zarr.storage import FsspecStore
|
|
11
11
|
|
|
12
12
|
from careamics.config import DataConfig
|
|
13
13
|
from careamics.config.support import SupportedData
|
|
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
|
20
20
|
|
|
21
21
|
# %%
|
|
22
22
|
def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
|
|
23
|
-
store =
|
|
23
|
+
store = FsspecStore.from_url(url=file_path.resolve())
|
|
24
24
|
# create array
|
|
25
25
|
array = zarr.create(
|
|
26
26
|
store=store,
|
|
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
|
|
|
61
61
|
# ### Make sure file exists
|
|
62
62
|
|
|
63
63
|
# %%
|
|
64
|
-
store =
|
|
64
|
+
store = FsspecStore.from_url(url=file_path.resolve(), mode="r")
|
|
65
65
|
|
|
66
66
|
# %%
|
|
67
67
|
list(store.keys())
|
|
@@ -72,7 +72,7 @@ list(store.keys())
|
|
|
72
72
|
|
|
73
73
|
# %%
|
|
74
74
|
class ZarrSource(TypedDict):
|
|
75
|
-
store:
|
|
75
|
+
store: FsspecStore
|
|
76
76
|
data_paths: Sequence[str]
|
|
77
77
|
|
|
78
78
|
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Literal, Union
|
|
3
|
+
from typing import Any, Literal, Self, Union
|
|
4
4
|
|
|
5
5
|
from numpy.typing import DTypeLike, NDArray
|
|
6
|
-
from typing_extensions import Self
|
|
7
6
|
|
|
8
7
|
from careamics.dataset.dataset_utils import reshape_array
|
|
9
8
|
from careamics.file_io.read import ReadFunc, read_tiff
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Union
|
|
3
|
+
from typing import Self, Union
|
|
4
4
|
|
|
5
|
+
import validators
|
|
5
6
|
import zarr
|
|
6
|
-
import zarr.storage
|
|
7
7
|
from numpy.typing import NDArray
|
|
8
|
-
from
|
|
8
|
+
from zarr.storage import FsspecStore, LocalStore
|
|
9
9
|
|
|
10
10
|
from careamics.dataset.dataset_utils import reshape_array
|
|
11
11
|
|
|
@@ -15,9 +15,10 @@ class ZarrImageStack:
|
|
|
15
15
|
A class for extracting patches from an image stack that is stored as a zarr array.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
# TODO:
|
|
19
|
-
#
|
|
20
|
-
|
|
18
|
+
# TODO: We should keep store type narrow
|
|
19
|
+
# - in zarr v3, does zarr.storage.Store exists and has the path attribute?
|
|
20
|
+
# - can we declare a narrow type rather than a union?
|
|
21
|
+
def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
|
|
21
22
|
self._store = store
|
|
22
23
|
self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
|
|
23
24
|
# TODO: validate axes
|
|
@@ -46,8 +47,33 @@ class ZarrImageStack:
|
|
|
46
47
|
Assumes the path only contains 1 image.
|
|
47
48
|
|
|
48
49
|
Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
|
|
50
|
+
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
path : Union[Path, str]
|
|
54
|
+
Path to the root of the OME-Zarr, local file or url.
|
|
55
|
+
|
|
56
|
+
Returns
|
|
57
|
+
-------
|
|
58
|
+
ZarrImageStack
|
|
59
|
+
Initialised ZarrImageStack.
|
|
60
|
+
|
|
61
|
+
Raises
|
|
62
|
+
------
|
|
63
|
+
ValueError
|
|
64
|
+
If the path does not exist or is not a valid URL.
|
|
65
|
+
ValueError
|
|
66
|
+
If the OME-Zarr at the path does not contain the attribute 'multiscales'.
|
|
49
67
|
"""
|
|
50
|
-
|
|
68
|
+
if Path(path).is_file():
|
|
69
|
+
store = zarr.storage.LocalStore(root=Path(path).resolve())
|
|
70
|
+
elif validators.url(path):
|
|
71
|
+
store = zarr.storage.FsspecStore.from_url(url=path)
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Path '{path}' is neither an existing file nor a valid URL."
|
|
75
|
+
)
|
|
76
|
+
|
|
51
77
|
group = zarr.open_group(store=store, mode="r")
|
|
52
78
|
if "multiscales" not in group.attrs:
|
|
53
79
|
raise ValueError(
|
|
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
|
|
|
38
38
|
|
|
39
39
|
>>> from typing import TypedDict
|
|
40
40
|
|
|
41
|
-
>>> from zarr.storage import
|
|
41
|
+
>>> from zarr.storage import FsspecStore
|
|
42
42
|
|
|
43
43
|
>>> from careamics.config import DataConfig
|
|
44
44
|
>>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
|
|
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
|
|
|
46
46
|
>>> # Define a zarr source
|
|
47
47
|
>>> # It encompasses multiple arguments that determine what data will be loaded
|
|
48
48
|
>>> class ZarrSource(TypedDict):
|
|
49
|
-
... store:
|
|
49
|
+
... store: FsspecStore
|
|
50
50
|
... data_paths: Sequence[str]
|
|
51
51
|
|
|
52
52
|
>>> def custom_image_stack_loader(
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""A module for random patching strategies."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
|
-
from typing import Optional
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
@@ -31,7 +30,7 @@ class RandomPatchingStrategy:
|
|
|
31
30
|
self,
|
|
32
31
|
data_shapes: Sequence[Sequence[int]],
|
|
33
32
|
patch_size: Sequence[int],
|
|
34
|
-
seed:
|
|
33
|
+
seed: int | None = None,
|
|
35
34
|
):
|
|
36
35
|
"""
|
|
37
36
|
A patching strategy for sampling random patches.
|
|
@@ -193,7 +192,7 @@ class FixedRandomPatchingStrategy:
|
|
|
193
192
|
self,
|
|
194
193
|
data_shapes: Sequence[Sequence[int]],
|
|
195
194
|
patch_size: Sequence[int],
|
|
196
|
-
seed:
|
|
195
|
+
seed: int | None = None,
|
|
197
196
|
):
|
|
198
197
|
"""A patching strategy for sampling random patches.
|
|
199
198
|
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import itertools
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from typing import Optional
|
|
4
3
|
|
|
5
4
|
import numpy as np
|
|
6
5
|
from typing_extensions import ParamSpec
|
|
@@ -18,7 +17,7 @@ class SequentialPatchingStrategy:
|
|
|
18
17
|
self,
|
|
19
18
|
data_shapes: Sequence[Sequence[int]],
|
|
20
19
|
patch_size: Sequence[int],
|
|
21
|
-
overlaps:
|
|
20
|
+
overlaps: Sequence[int] | None = None,
|
|
22
21
|
):
|
|
23
22
|
self.data_shapes = data_shapes
|
|
24
23
|
self.patch_size = patch_size
|
careamics/lightning/__init__.py
CHANGED
|
@@ -1,18 +1,32 @@
|
|
|
1
1
|
"""CAREamics PyTorch Lightning modules."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"FCNModule",
|
|
5
6
|
"HyperParametersCallback",
|
|
7
|
+
"MicroSplitDataModule",
|
|
6
8
|
"PredictDataModule",
|
|
7
9
|
"ProgressBarCallback",
|
|
8
10
|
"TrainDataModule",
|
|
9
11
|
"VAEModule",
|
|
10
12
|
"create_careamics_module",
|
|
13
|
+
"create_microsplit_predict_datamodule",
|
|
14
|
+
"create_microsplit_train_datamodule",
|
|
11
15
|
"create_predict_datamodule",
|
|
12
16
|
"create_train_datamodule",
|
|
17
|
+
"create_unet_based_module",
|
|
18
|
+
"create_vae_based_module",
|
|
13
19
|
]
|
|
14
20
|
|
|
15
|
-
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
21
|
+
from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
|
|
16
22
|
from .lightning_module import FCNModule, VAEModule, create_careamics_module
|
|
23
|
+
from .microsplit_data_module import (
|
|
24
|
+
MicroSplitDataModule,
|
|
25
|
+
create_microsplit_predict_datamodule,
|
|
26
|
+
create_microsplit_train_datamodule,
|
|
27
|
+
)
|
|
17
28
|
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
18
|
-
from .train_data_module import
|
|
29
|
+
from .train_data_module import (
|
|
30
|
+
TrainDataModule,
|
|
31
|
+
create_train_datamodule,
|
|
32
|
+
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
"""Callbacks module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"HyperParametersCallback",
|
|
5
6
|
"PredictionWriterCallback",
|
|
6
7
|
"ProgressBarCallback",
|
|
7
8
|
]
|
|
8
9
|
|
|
10
|
+
from .data_stats_callback import DataStatsCallback
|
|
9
11
|
from .hyperparameters_callback import HyperParametersCallback
|
|
10
12
|
from .prediction_writer_callback import PredictionWriterCallback
|
|
11
13
|
from .progress_bar_callback import ProgressBarCallback
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Data statistics callback."""
|
|
2
|
+
|
|
3
|
+
import pytorch_lightning as L
|
|
4
|
+
from pytorch_lightning.callbacks import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DataStatsCallback(Callback):
|
|
8
|
+
"""Callback to update model's data statistics from datamodule.
|
|
9
|
+
|
|
10
|
+
This callback ensures that the model has access to the data statistics (mean and std)
|
|
11
|
+
calculated by the datamodule before training starts.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
|
|
15
|
+
"""Called when trainer is setting up."""
|
|
16
|
+
if stage == "fit":
|
|
17
|
+
# Get data statistics from datamodule
|
|
18
|
+
(data_mean, data_std), _ = trainer.datamodule.get_data_stats()
|
|
19
|
+
|
|
20
|
+
# Set data statistics in the model's likelihood module
|
|
21
|
+
module.noise_model_likelihood.set_data_stats(
|
|
22
|
+
data_mean=data_mean["target"], data_std=data_std["target"]
|
|
23
|
+
)
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Union
|
|
8
8
|
|
|
9
9
|
from pytorch_lightning import LightningModule, Trainer
|
|
10
10
|
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
@@ -84,9 +84,9 @@ class PredictionWriterCallback(BasePredictionWriter):
|
|
|
84
84
|
cls,
|
|
85
85
|
write_type: SupportedWriteType,
|
|
86
86
|
tiled: bool,
|
|
87
|
-
write_func:
|
|
88
|
-
write_extension:
|
|
89
|
-
write_func_kwargs:
|
|
87
|
+
write_func: WriteFunc | None = None,
|
|
88
|
+
write_extension: str | None = None,
|
|
89
|
+
write_func_kwargs: dict[str, Any] | None = None,
|
|
90
90
|
dirpath: Union[Path, str] = "predictions",
|
|
91
91
|
) -> PredictionWriterCallback: # TODO: change type hint to self (find out how)
|
|
92
92
|
"""
|
|
@@ -172,7 +172,7 @@ class PredictionWriterCallback(BasePredictionWriter):
|
|
|
172
172
|
trainer: Trainer,
|
|
173
173
|
pl_module: LightningModule,
|
|
174
174
|
prediction: Any, # TODO: change to expected type
|
|
175
|
-
batch_indices:
|
|
175
|
+
batch_indices: Sequence[int] | None,
|
|
176
176
|
batch: Any, # TODO: change to expected type
|
|
177
177
|
batch_idx: int,
|
|
178
178
|
dataloader_idx: int,
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Protocol, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from numpy.typing import NDArray
|
|
@@ -25,7 +25,7 @@ class WriteStrategy(Protocol):
|
|
|
25
25
|
trainer: Trainer,
|
|
26
26
|
pl_module: LightningModule,
|
|
27
27
|
prediction: Any, # TODO: change to expected type
|
|
28
|
-
batch_indices:
|
|
28
|
+
batch_indices: Sequence[int] | None,
|
|
29
29
|
batch: Any, # TODO: change to expected type
|
|
30
30
|
batch_idx: int,
|
|
31
31
|
dataloader_idx: int,
|
|
@@ -133,7 +133,7 @@ class CacheTiles(WriteStrategy):
|
|
|
133
133
|
trainer: Trainer,
|
|
134
134
|
pl_module: LightningModule,
|
|
135
135
|
prediction: tuple[NDArray, list[TileInformation]],
|
|
136
|
-
batch_indices:
|
|
136
|
+
batch_indices: Sequence[int] | None,
|
|
137
137
|
batch: tuple[NDArray, list[TileInformation]],
|
|
138
138
|
batch_idx: int,
|
|
139
139
|
dataloader_idx: int,
|
|
@@ -259,7 +259,7 @@ class WriteTilesZarr(WriteStrategy):
|
|
|
259
259
|
trainer: Trainer,
|
|
260
260
|
pl_module: LightningModule,
|
|
261
261
|
prediction: Any,
|
|
262
|
-
batch_indices:
|
|
262
|
+
batch_indices: Sequence[int] | None,
|
|
263
263
|
batch: Any,
|
|
264
264
|
batch_idx: int,
|
|
265
265
|
dataloader_idx: int,
|
|
@@ -346,7 +346,7 @@ class WriteImage(WriteStrategy):
|
|
|
346
346
|
trainer: Trainer,
|
|
347
347
|
pl_module: LightningModule,
|
|
348
348
|
prediction: NDArray,
|
|
349
|
-
batch_indices:
|
|
349
|
+
batch_indices: Sequence[int] | None,
|
|
350
350
|
batch: NDArray,
|
|
351
351
|
batch_idx: int,
|
|
352
352
|
dataloader_idx: int,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Module containing convenience function to create `WriteStrategy`."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from careamics.config.support import SupportedData
|
|
6
6
|
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
@@ -11,9 +11,9 @@ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
|
|
|
11
11
|
def create_write_strategy(
|
|
12
12
|
write_type: SupportedWriteType,
|
|
13
13
|
tiled: bool,
|
|
14
|
-
write_func:
|
|
15
|
-
write_extension:
|
|
16
|
-
write_func_kwargs:
|
|
14
|
+
write_func: WriteFunc | None = None,
|
|
15
|
+
write_extension: str | None = None,
|
|
16
|
+
write_func_kwargs: dict[str, Any] | None = None,
|
|
17
17
|
) -> WriteStrategy:
|
|
18
18
|
"""
|
|
19
19
|
Create a write strategy from convenient parameters.
|
|
@@ -78,8 +78,8 @@ def create_write_strategy(
|
|
|
78
78
|
|
|
79
79
|
def _create_tiled_write_strategy(
|
|
80
80
|
write_type: SupportedWriteType,
|
|
81
|
-
write_func:
|
|
82
|
-
write_extension:
|
|
81
|
+
write_func: WriteFunc | None,
|
|
82
|
+
write_extension: str | None,
|
|
83
83
|
write_func_kwargs: dict[str, Any],
|
|
84
84
|
) -> WriteStrategy:
|
|
85
85
|
"""
|
|
@@ -130,7 +130,7 @@ def _create_tiled_write_strategy(
|
|
|
130
130
|
|
|
131
131
|
|
|
132
132
|
def select_write_func(
|
|
133
|
-
write_type: SupportedWriteType, write_func:
|
|
133
|
+
write_type: SupportedWriteType, write_func: WriteFunc | None = None
|
|
134
134
|
) -> WriteFunc:
|
|
135
135
|
"""
|
|
136
136
|
Return a function to write images.
|
|
@@ -177,7 +177,7 @@ def select_write_func(
|
|
|
177
177
|
|
|
178
178
|
|
|
179
179
|
def select_write_extension(
|
|
180
|
-
write_type: SupportedWriteType, write_extension:
|
|
180
|
+
write_type: SupportedWriteType, write_extension: str | None = None
|
|
181
181
|
) -> str:
|
|
182
182
|
"""
|
|
183
183
|
Return an extension to add to file paths.
|