careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +6 -12
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +843 -29
- careamics/config/data/data_model.py +1 -2
- careamics/config/data/ng_data_model.py +1 -2
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/lightning_module.py +161 -61
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/xy_random_rotate90.py +1 -1
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ import os
|
|
|
6
6
|
import sys
|
|
7
7
|
from collections.abc import Sequence
|
|
8
8
|
from pprint import pformat
|
|
9
|
-
from typing import Annotated, Any, Literal, Union
|
|
9
|
+
from typing import Annotated, Any, Literal, Self, Union
|
|
10
10
|
from warnings import warn
|
|
11
11
|
|
|
12
12
|
import numpy as np
|
|
@@ -19,7 +19,6 @@ from pydantic import (
|
|
|
19
19
|
field_validator,
|
|
20
20
|
model_validator,
|
|
21
21
|
)
|
|
22
|
-
from typing_extensions import Self
|
|
23
22
|
|
|
24
23
|
from ..transformations import XYFlipModel, XYRandomRotate90Model
|
|
25
24
|
from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
6
|
from pprint import pformat
|
|
7
|
-
from typing import Annotated, Any, Literal, Union
|
|
7
|
+
from typing import Annotated, Any, Literal, Self, Union
|
|
8
8
|
from warnings import warn
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
@@ -17,7 +17,6 @@ from pydantic import (
|
|
|
17
17
|
field_validator,
|
|
18
18
|
model_validator,
|
|
19
19
|
)
|
|
20
|
-
from typing_extensions import Self
|
|
21
20
|
|
|
22
21
|
from ..transformations import XYFlipModel, XYRandomRotate90Model
|
|
23
22
|
from ..validators import check_axes_validity
|
|
@@ -2,10 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import Any, Literal, Union
|
|
5
|
+
from typing import Any, Literal, Self, Union
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
|
-
from typing_extensions import Self
|
|
9
8
|
|
|
10
9
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
11
10
|
|
|
@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
|
|
|
50
50
|
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
51
51
|
|
|
52
52
|
# TODO remove and use as parameters to the likelihood functions?
|
|
53
|
-
data_mean: Tensor =
|
|
53
|
+
data_mean: Tensor | None = None
|
|
54
54
|
"""The mean of the data, used to unnormalize data for noise model evaluation.
|
|
55
55
|
Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|
|
56
56
|
|
|
57
57
|
# TODO remove and use as parameters to the likelihood functions?
|
|
58
|
-
data_std: Tensor =
|
|
58
|
+
data_std: Tensor | None = None
|
|
59
59
|
"""The standard deviation of the data, used to unnormalize data for noise
|
|
60
60
|
model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|
careamics/config/loss_model.py
CHANGED
|
@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
|
|
|
35
35
|
validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
|
|
36
36
|
)
|
|
37
37
|
|
|
38
|
-
loss_type: Literal[
|
|
38
|
+
loss_type: Literal[
|
|
39
|
+
"hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
|
|
40
|
+
]
|
|
39
41
|
"""Type of loss to use for LVAE."""
|
|
40
42
|
|
|
41
43
|
reconstruction_weight: float = 1.0
|
|
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
|
|
|
50
52
|
"""Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
|
|
51
53
|
kl_params: KLLossConfig = KLLossConfig()
|
|
52
54
|
"""KL loss configuration."""
|
|
53
|
-
|
|
55
|
+
# TODO revisit weights for the losses
|
|
54
56
|
# TODO: remove?
|
|
55
57
|
non_stochastic: bool = False
|
|
56
58
|
"""Whether to sample latents and compute KL."""
|
|
59
|
+
|
|
60
|
+
# TODO what are the correct parameters for HDN ?
|
careamics/config/nm_model.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Noise models config."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Annotated, Literal, Union
|
|
4
|
+
from typing import Annotated, Literal, Self, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import torch
|
|
@@ -11,6 +11,7 @@ from pydantic import (
|
|
|
11
11
|
Field,
|
|
12
12
|
PlainSerializer,
|
|
13
13
|
PlainValidator,
|
|
14
|
+
model_validator,
|
|
14
15
|
)
|
|
15
16
|
|
|
16
17
|
from careamics.utils.serializers import _array_to_json, _to_numpy
|
|
@@ -86,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
|
|
|
86
87
|
tol: float = Field(default=1e-10)
|
|
87
88
|
"""Tolerance used in the computation of the noise model likelihood."""
|
|
88
89
|
|
|
90
|
+
@model_validator(mode="after")
|
|
91
|
+
def validate_path(self: Self) -> Self:
|
|
92
|
+
"""Validate that the path points to a valid .npz file if provided.
|
|
93
|
+
|
|
94
|
+
Returns
|
|
95
|
+
-------
|
|
96
|
+
Self
|
|
97
|
+
Returns itself.
|
|
98
|
+
|
|
99
|
+
Raises
|
|
100
|
+
------
|
|
101
|
+
ValueError
|
|
102
|
+
If the path is provided but does not point to a valid .npz file.
|
|
103
|
+
"""
|
|
104
|
+
if self.path is not None:
|
|
105
|
+
path = Path(self.path)
|
|
106
|
+
if not path.exists():
|
|
107
|
+
raise ValueError(f"Path {path} does not exist.")
|
|
108
|
+
if path.suffix != ".npz":
|
|
109
|
+
raise ValueError(f"Path {path} must point to a .npz file.")
|
|
110
|
+
if not path.is_file():
|
|
111
|
+
raise ValueError(f"Path {path} must point to a file.")
|
|
112
|
+
return self
|
|
113
|
+
|
|
89
114
|
# @model_validator(mode="after")
|
|
90
115
|
# def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
|
|
91
116
|
# """Validate paths provided in the config.
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import Literal
|
|
5
|
+
from typing import Literal, Self
|
|
6
6
|
|
|
7
7
|
from pydantic import (
|
|
8
8
|
BaseModel,
|
|
@@ -13,7 +13,6 @@ from pydantic import (
|
|
|
13
13
|
model_validator,
|
|
14
14
|
)
|
|
15
15
|
from torch import optim
|
|
16
|
-
from typing_extensions import Self
|
|
17
16
|
|
|
18
17
|
from careamics.utils.torch_utils import filter_parameters
|
|
19
18
|
|
|
@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
|
|
|
26
26
|
MUSPLIT = "musplit"
|
|
27
27
|
"""An image splitting approach based on ladder VAE architectures."""
|
|
28
28
|
|
|
29
|
+
MICROSPLIT = "microsplit"
|
|
30
|
+
"""A micro-level image splitting approach based on ladder VAE architectures."""
|
|
31
|
+
|
|
29
32
|
DENOISPLIT = "denoisplit"
|
|
30
33
|
"""An image splitting and denoising approach based on ladder VAE architectures."""
|
|
31
34
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
# SEG = "segmentation"
|
|
35
|
+
HDN = "hdn"
|
|
36
|
+
"""Hierarchical Denoising Network, an unsupervised denoising algorithm"""
|
|
@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
|
|
|
21
21
|
MAE = "mae"
|
|
22
22
|
N2V = "n2v"
|
|
23
23
|
# PN2V = "pn2v"
|
|
24
|
-
|
|
24
|
+
HDN = "hdn"
|
|
25
25
|
MUSPLIT = "musplit"
|
|
26
|
+
MICROSPLIT = "microsplit"
|
|
26
27
|
DENOISPLIT = "denoisplit"
|
|
27
|
-
DENOISPLIT_MUSPLIT =
|
|
28
|
+
DENOISPLIT_MUSPLIT = (
|
|
29
|
+
"denoisplit_musplit" # TODO refac losses, leave only microsplit
|
|
30
|
+
)
|
|
28
31
|
# CE = "ce"
|
|
29
32
|
# DICE = "dice"
|
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal
|
|
6
|
+
from typing import Literal
|
|
7
7
|
|
|
8
|
-
from pydantic import BaseModel, ConfigDict, Field
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
9
9
|
|
|
10
10
|
from .callback_model import CheckpointModel, EarlyStoppingModel
|
|
11
11
|
|
|
@@ -29,26 +29,15 @@ class TrainingConfig(BaseModel):
|
|
|
29
29
|
model_config = ConfigDict(
|
|
30
30
|
validate_assignment=True,
|
|
31
31
|
)
|
|
32
|
+
lightning_trainer_config: dict | None = None
|
|
33
|
+
"""Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
|
|
34
|
+
Trainer class"""
|
|
32
35
|
|
|
33
|
-
num_epochs: int = Field(default=20, ge=1)
|
|
34
|
-
"""Number of epochs, greater than 0."""
|
|
35
|
-
|
|
36
|
-
precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
|
|
37
|
-
"""Numerical precision"""
|
|
38
|
-
max_steps: int = Field(default=-1, ge=-1)
|
|
39
|
-
"""Maximum number of steps to train for. -1 means no limit."""
|
|
40
|
-
check_val_every_n_epoch: int = Field(default=1, ge=1)
|
|
41
|
-
"""Validation step frequency."""
|
|
42
|
-
accumulate_grad_batches: int = Field(default=1, ge=1)
|
|
43
|
-
"""Number of batches to accumulate gradients over before stepping the optimizer."""
|
|
44
|
-
gradient_clip_val: Union[int, float] | None = None
|
|
45
|
-
"""The value to which to clip the gradient"""
|
|
46
|
-
gradient_clip_algorithm: Literal["value", "norm"] = "norm"
|
|
47
|
-
"""The algorithm to use for gradient clipping (see lightning `Trainer`)."""
|
|
48
36
|
logger: Literal["wandb", "tensorboard"] | None = None
|
|
49
37
|
"""Logger to use during training. If None, no logger will be used. Available
|
|
50
38
|
loggers are defined in SupportedLogger."""
|
|
51
39
|
|
|
40
|
+
# Only basic callbacks
|
|
52
41
|
checkpoint_callback: CheckpointModel = CheckpointModel()
|
|
53
42
|
"""Checkpoint callback configuration, following PyTorch Lightning Checkpoint
|
|
54
43
|
callback."""
|
|
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
|
|
|
78
67
|
Whether the logger is defined or not.
|
|
79
68
|
"""
|
|
80
69
|
return self.logger is not None
|
|
81
|
-
|
|
82
|
-
@field_validator("max_steps")
|
|
83
|
-
@classmethod
|
|
84
|
-
def validate_max_steps(cls, max_steps: int) -> int:
|
|
85
|
-
"""Validate the max_steps parameter.
|
|
86
|
-
|
|
87
|
-
Parameters
|
|
88
|
-
----------
|
|
89
|
-
max_steps : int
|
|
90
|
-
Maximum number of steps to train for. -1 means no limit.
|
|
91
|
-
|
|
92
|
-
Returns
|
|
93
|
-
-------
|
|
94
|
-
int
|
|
95
|
-
Validated max_steps.
|
|
96
|
-
"""
|
|
97
|
-
if max_steps == 0:
|
|
98
|
-
raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
|
|
99
|
-
return max_steps
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
"""Pydantic model for the Normalize transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal, Self
|
|
4
4
|
|
|
5
5
|
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
-
from typing_extensions import Self
|
|
7
6
|
|
|
8
7
|
from .transform_model import TransformModel
|
|
9
8
|
|
|
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import zarr
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
|
-
from zarr.storage import
|
|
10
|
+
from zarr.storage import FsspecStore
|
|
11
11
|
|
|
12
12
|
from careamics.config import DataConfig
|
|
13
13
|
from careamics.config.support import SupportedData
|
|
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
|
20
20
|
|
|
21
21
|
# %%
|
|
22
22
|
def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
|
|
23
|
-
store =
|
|
23
|
+
store = FsspecStore.from_url(url=file_path.resolve())
|
|
24
24
|
# create array
|
|
25
25
|
array = zarr.create(
|
|
26
26
|
store=store,
|
|
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
|
|
|
61
61
|
# ### Make sure file exists
|
|
62
62
|
|
|
63
63
|
# %%
|
|
64
|
-
store =
|
|
64
|
+
store = FsspecStore.from_url(url=file_path.resolve(), mode="r")
|
|
65
65
|
|
|
66
66
|
# %%
|
|
67
67
|
list(store.keys())
|
|
@@ -72,7 +72,7 @@ list(store.keys())
|
|
|
72
72
|
|
|
73
73
|
# %%
|
|
74
74
|
class ZarrSource(TypedDict):
|
|
75
|
-
store:
|
|
75
|
+
store: FsspecStore
|
|
76
76
|
data_paths: Sequence[str]
|
|
77
77
|
|
|
78
78
|
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Literal, Union
|
|
3
|
+
from typing import Any, Literal, Self, Union
|
|
4
4
|
|
|
5
5
|
from numpy.typing import DTypeLike, NDArray
|
|
6
|
-
from typing_extensions import Self
|
|
7
6
|
|
|
8
7
|
from careamics.dataset.dataset_utils import reshape_array
|
|
9
8
|
from careamics.file_io.read import ReadFunc, read_tiff
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Union
|
|
3
|
+
from typing import Self, Union
|
|
4
4
|
|
|
5
|
+
import validators
|
|
5
6
|
import zarr
|
|
6
|
-
import zarr.storage
|
|
7
7
|
from numpy.typing import NDArray
|
|
8
|
-
from
|
|
8
|
+
from zarr.storage import FsspecStore, LocalStore
|
|
9
9
|
|
|
10
10
|
from careamics.dataset.dataset_utils import reshape_array
|
|
11
11
|
|
|
@@ -15,9 +15,10 @@ class ZarrImageStack:
|
|
|
15
15
|
A class for extracting patches from an image stack that is stored as a zarr array.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
# TODO:
|
|
19
|
-
#
|
|
20
|
-
|
|
18
|
+
# TODO: We should keep store type narrow
|
|
19
|
+
# - in zarr v3, does zarr.storage.Store exists and has the path attribute?
|
|
20
|
+
# - can we declare a narrow type rather than a union?
|
|
21
|
+
def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
|
|
21
22
|
self._store = store
|
|
22
23
|
self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
|
|
23
24
|
# TODO: validate axes
|
|
@@ -46,8 +47,33 @@ class ZarrImageStack:
|
|
|
46
47
|
Assumes the path only contains 1 image.
|
|
47
48
|
|
|
48
49
|
Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
|
|
50
|
+
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
path : Union[Path, str]
|
|
54
|
+
Path to the root of the OME-Zarr, local file or url.
|
|
55
|
+
|
|
56
|
+
Returns
|
|
57
|
+
-------
|
|
58
|
+
ZarrImageStack
|
|
59
|
+
Initialised ZarrImageStack.
|
|
60
|
+
|
|
61
|
+
Raises
|
|
62
|
+
------
|
|
63
|
+
ValueError
|
|
64
|
+
If the path does not exist or is not a valid URL.
|
|
65
|
+
ValueError
|
|
66
|
+
If the OME-Zarr at the path does not contain the attribute 'multiscales'.
|
|
49
67
|
"""
|
|
50
|
-
|
|
68
|
+
if Path(path).is_file():
|
|
69
|
+
store = zarr.storage.LocalStore(root=Path(path).resolve())
|
|
70
|
+
elif validators.url(path):
|
|
71
|
+
store = zarr.storage.FsspecStore.from_url(url=path)
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Path '{path}' is neither an existing file nor a valid URL."
|
|
75
|
+
)
|
|
76
|
+
|
|
51
77
|
group = zarr.open_group(store=store, mode="r")
|
|
52
78
|
if "multiscales" not in group.attrs:
|
|
53
79
|
raise ValueError(
|
|
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
|
|
|
38
38
|
|
|
39
39
|
>>> from typing import TypedDict
|
|
40
40
|
|
|
41
|
-
>>> from zarr.storage import
|
|
41
|
+
>>> from zarr.storage import FsspecStore
|
|
42
42
|
|
|
43
43
|
>>> from careamics.config import DataConfig
|
|
44
44
|
>>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
|
|
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
|
|
|
46
46
|
>>> # Define a zarr source
|
|
47
47
|
>>> # It encompasses multiple arguments that determine what data will be loaded
|
|
48
48
|
>>> class ZarrSource(TypedDict):
|
|
49
|
-
... store:
|
|
49
|
+
... store: FsspecStore
|
|
50
50
|
... data_paths: Sequence[str]
|
|
51
51
|
|
|
52
52
|
>>> def custom_image_stack_loader(
|
careamics/lightning/__init__.py
CHANGED
|
@@ -1,18 +1,32 @@
|
|
|
1
1
|
"""CAREamics PyTorch Lightning modules."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"FCNModule",
|
|
5
6
|
"HyperParametersCallback",
|
|
7
|
+
"MicroSplitDataModule",
|
|
6
8
|
"PredictDataModule",
|
|
7
9
|
"ProgressBarCallback",
|
|
8
10
|
"TrainDataModule",
|
|
9
11
|
"VAEModule",
|
|
10
12
|
"create_careamics_module",
|
|
13
|
+
"create_microsplit_predict_datamodule",
|
|
14
|
+
"create_microsplit_train_datamodule",
|
|
11
15
|
"create_predict_datamodule",
|
|
12
16
|
"create_train_datamodule",
|
|
17
|
+
"create_unet_based_module",
|
|
18
|
+
"create_vae_based_module",
|
|
13
19
|
]
|
|
14
20
|
|
|
15
|
-
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
21
|
+
from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
|
|
16
22
|
from .lightning_module import FCNModule, VAEModule, create_careamics_module
|
|
23
|
+
from .microsplit_data_module import (
|
|
24
|
+
MicroSplitDataModule,
|
|
25
|
+
create_microsplit_predict_datamodule,
|
|
26
|
+
create_microsplit_train_datamodule,
|
|
27
|
+
)
|
|
17
28
|
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
18
|
-
from .train_data_module import
|
|
29
|
+
from .train_data_module import (
|
|
30
|
+
TrainDataModule,
|
|
31
|
+
create_train_datamodule,
|
|
32
|
+
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
"""Callbacks module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
4
5
|
"HyperParametersCallback",
|
|
5
6
|
"PredictionWriterCallback",
|
|
6
7
|
"ProgressBarCallback",
|
|
7
8
|
]
|
|
8
9
|
|
|
10
|
+
from .data_stats_callback import DataStatsCallback
|
|
9
11
|
from .hyperparameters_callback import HyperParametersCallback
|
|
10
12
|
from .prediction_writer_callback import PredictionWriterCallback
|
|
11
13
|
from .progress_bar_callback import ProgressBarCallback
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Data statistics callback."""
|
|
2
|
+
|
|
3
|
+
import pytorch_lightning as L
|
|
4
|
+
from pytorch_lightning.callbacks import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DataStatsCallback(Callback):
|
|
8
|
+
"""Callback to update model's data statistics from datamodule.
|
|
9
|
+
|
|
10
|
+
This callback ensures that the model has access to the data statistics (mean and std)
|
|
11
|
+
calculated by the datamodule before training starts.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
|
|
15
|
+
"""Called when trainer is setting up."""
|
|
16
|
+
if stage == "fit":
|
|
17
|
+
# Get data statistics from datamodule
|
|
18
|
+
(data_mean, data_std), _ = trainer.datamodule.get_data_stats()
|
|
19
|
+
|
|
20
|
+
# Set data statistics in the model's likelihood module
|
|
21
|
+
module.noise_model_likelihood.set_data_stats(
|
|
22
|
+
data_mean=data_mean["target"], data_std=data_std["target"]
|
|
23
|
+
)
|