careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff shows the changes between two publicly released versions of the careamics package, as published to a public registry, and is provided for informational purposes only.
- careamics/careamist.py +55 -61
- careamics/cli/conf.py +24 -9
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +53 -18
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +15 -11
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +892 -78
- careamics/config/data/data_model.py +7 -14
- careamics/config/data/ng_data_model.py +8 -15
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -11
- careamics/config/likelihood_model.py +4 -4
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +30 -7
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +8 -38
- careamics/config/transformations/normalize_model.py +3 -4
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +166 -68
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +16 -9
- careamics/lightning/train_data_module.py +29 -18
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +94 -9
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +12 -8
- careamics/models/layers.py +5 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +4 -6
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/config/data/data_model.py
CHANGED

@@ -6,7 +6,7 @@ import os
 import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn

 import numpy as np
@@ -19,7 +19,6 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self

 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -109,22 +108,16 @@ class DataConfig(BaseModel):
     """Batch size for training."""

     # Optional fields
-    image_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""

-    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""

-    target_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the target data across channels, used for normalization."""

-    target_stds: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the target data across channels, used for
     normalization."""

@@ -388,8 +381,8 @@ class DataConfig(BaseModel):
         self,
         image_means: Union[NDArray, tuple, list, None],
         image_stds: Union[NDArray, tuple, list, None],
-        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
-        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_means: Union[NDArray, tuple, list, None] | None = None,
+        target_stds: Union[NDArray, tuple, list, None] | None = None,
     ) -> None:
         """
         Set mean and standard deviation of the data across channels.
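Editorial note: this hunk, and most of the hunks below, are a mechanical migration from `typing.Optional[X]` to the PEP 604 spelling `X | None`, which requires Python 3.10+ when the annotations are evaluated at runtime. A minimal, self-contained sketch of the new field style (a standalone pydantic v2 model for illustration, not the careamics `DataConfig` itself):

from pydantic import BaseModel, Field


class StatsSketch(BaseModel):
    # "list[float] | None" is equivalent to the old "Optional[list[float]]".
    image_means: list[float] | None = Field(default=None, max_length=32)
    image_stds: list[float] | None = Field(default=None, max_length=32)


assert StatsSketch().image_means is None  # default unchanged
assert StatsSketch(image_means=[0.5]).image_means == [0.5]
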
careamics/config/data/ng_data_model.py
CHANGED

@@ -4,7 +4,7 @@ from __future__ import annotations

 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn

 import numpy as np
@@ -17,7 +17,6 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import Self

 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity
@@ -106,22 +105,16 @@ class NGDataConfig(BaseModel):
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     """Batch size for training."""

-    image_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""

-    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""

-    target_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the target data across channels, used for normalization."""

-    target_stds: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the target data across channels, used for
     normalization."""

@@ -148,7 +141,7 @@ class NGDataConfig(BaseModel):
     test_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch test dataloader parameters."""

-    seed: Optional[int] = Field(default=None, gt=0)
+    seed: int | None = Field(default=None, gt=0)
     """Random seed for reproducibility."""

     @field_validator("axes")
@@ -330,8 +323,8 @@ class NGDataConfig(BaseModel):
         self,
         image_means: Union[NDArray, tuple, list, None],
         image_stds: Union[NDArray, tuple, list, None],
-        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
-        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_means: Union[NDArray, tuple, list, None] | None = None,
+        target_stds: Union[NDArray, tuple, list, None] | None = None,
     ) -> None:
         """
         Set mean and standard deviation of the data across channels.
careamics/config/data/patching_strategies/_overlapping_patched_model.py
CHANGED

@@ -1,7 +1,6 @@
 """Sequential patching Pydantic model."""

 from collections.abc import Sequence
-from typing import Optional

 from pydantic import Field, ValidationInfo, field_validator

@@ -24,7 +23,7 @@ class _OverlappingPatchedModel(_PatchedModel):
     dimension, and the number of dimensions be either 2 or 3.
     """

-    overlaps: Optional[Sequence[int]] = Field(
+    overlaps: Sequence[int] | None = Field(
         default=None,
         min_length=2,
         max_length=3,
@@ -37,8 +36,8 @@ class _OverlappingPatchedModel(_PatchedModel):
     @field_validator("overlaps")
     @classmethod
     def overlap_smaller_than_patch_size(
-        cls, overlaps: Optional[Sequence[int]], values: ValidationInfo
-    ) -> Optional[Sequence[int]]:
+        cls, overlaps: Sequence[int] | None, values: ValidationInfo
+    ) -> Sequence[int] | None:
         """
         Validate overlap.

@@ -78,7 +77,7 @@ class _OverlappingPatchedModel(_PatchedModel):

     @field_validator("overlaps")
     @classmethod
-    def overlap_even(cls, overlaps: Optional[Sequence[int]]) -> Optional[Sequence[int]]:
+    def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
         """
         Validate overlaps.

careamics/config/inference_model.py
CHANGED

@@ -2,10 +2,9 @@

 from __future__ import annotations

-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Self, Union

 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from typing_extensions import Self

 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2

@@ -18,12 +17,10 @@ class InferenceConfig(BaseModel):
     data_type: Literal["array", "tiff", "czi", "custom"]  # As defined in SupportedData
     """Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""

-    tile_size: Optional[Union[list[int]]] = Field(
-        default=None, min_length=2, max_length=3
-    )
+    tile_size: Union[list[int]] | None = Field(default=None, min_length=2, max_length=3)
     """Tile size of prediction, only effective if `tile_overlap` is specified."""

-    tile_overlap: Optional[Union[list[int]]] = Field(
+    tile_overlap: Union[list[int]] | None = Field(
         default=None, min_length=2, max_length=3
     )
     """Overlap between tiles, only effective if `tile_size` is specified."""
@@ -48,8 +45,8 @@ class InferenceConfig(BaseModel):
     @field_validator("tile_overlap")
     @classmethod
     def all_elements_non_zero_even(
-        cls, tile_overlap: Optional[list[int]]
-    ) -> Optional[list[int]]:
+        cls, tile_overlap: list[int] | None
+    ) -> list[int] | None:
         """
         Validate tile overlap.

@@ -86,9 +83,7 @@ class InferenceConfig(BaseModel):

     @field_validator("tile_size")
     @classmethod
-    def tile_min_8_power_of_2(
-        cls, tile_list: Optional[list[int]]
-    ) -> Optional[list[int]]:
+    def tile_min_8_power_of_2(cls, tile_list: list[int] | None) -> list[int] | None:
         """
         Validate that each entry is greater or equal than 8 and a power of 2.

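Editorial note: `tile_min_8_power_of_2` enforces that every tile dimension is at least 8 and a power of two. A hedged sketch of such a check (my own illustration using the standard bit trick, not the careamics implementation, which lives in `careamics/config/validators`):

def is_valid_tile_dim(value: int) -> bool:
    # A positive power of two has exactly one set bit, so value & (value - 1) == 0.
    return value >= 8 and (value & (value - 1)) == 0


assert is_valid_tile_dim(8) and is_valid_tile_dim(256)
assert not is_valid_tile_dim(12) and not is_valid_tile_dim(4)
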
careamics/config/likelihood_model.py
CHANGED

@@ -1,6 +1,6 @@
 """Likelihood model."""

-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Literal, Union

 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class GaussianLikelihoodConfig(BaseModel):

     model_config = ConfigDict(validate_assignment=True)

-    predict_logvar: Optional[Literal["pixelwise"]] = None
+    predict_logvar: Literal["pixelwise"] | None = None
     """If `pixelwise`, log-variance is computed for each pixel, else log-variance
     is not computed."""

@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

     # TODO remove and use as parameters to the likelihood functions?
-    data_mean: Tensor = torch.zeros(1)
+    data_mean: Tensor | None = None
     """The mean of the data, used to unnormalize data for noise model evaluation.
     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

     # TODO remove and use as parameters to the likelihood functions?
-    data_std: Tensor = torch.ones(1)
+    data_std: Tensor | None = None
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
careamics/config/loss_model.py
CHANGED

@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
         validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
     )

-    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+    loss_type: Literal[
+        "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
+    ]
     """Type of loss to use for LVAE."""

     reconstruction_weight: float = 1.0
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_params: KLLossConfig = KLLossConfig()
     """KL loss configuration."""
-
+    # TODO revisit weights for the losses
     # TODO: remove?
     non_stochastic: bool = False
     """Whether to sample latents and compute KL."""
+
+    # TODO what are the correct parameters for HDN ?
careamics/config/nm_model.py
CHANGED

@@ -1,7 +1,7 @@
 """Noise models config."""

 from pathlib import Path
-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Literal, Self, Union

 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from pydantic import (
     Field,
     PlainSerializer,
     PlainValidator,
+    model_validator,
 )

 from careamics.utils.serializers import _array_to_json, _to_numpy
@@ -42,21 +43,19 @@ class GaussianMixtureNMConfig(BaseModel):
     # model type
     model_type: Literal["GaussianMixtureNoiseModel"]

-    path: Optional[Union[Path, str]] = None
+    path: Union[Path, str] | None = None
     """Path to the directory where the trained noise model (*.npz) is saved in the
     `train` method."""

     # TODO remove and use as parameters to the NM functions?
-    signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
+    signal: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
     """Path to the file containing signal or respective numpy array."""

     # TODO remove and use as parameters to the NM functions?
-    observation: Optional[Union[str, Path, np.ndarray]] = Field(
-        default=None, exclude=True
-    )
+    observation: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
     """Path to the file containing observation or respective numpy array."""

-    weight: Optional[Array] = None
+    weight: Array | None = None
     """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
     describing the GMM noise model, with each row corresponding to one
     parameter of each gaussian, namely [mean, standard deviation and weight].
@@ -88,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
     tol: float = Field(default=1e-10)
     """Tolerance used in the computation of the noise model likelihood."""

+    @model_validator(mode="after")
+    def validate_path(self: Self) -> Self:
+        """Validate that the path points to a valid .npz file if provided.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+
+        Raises
+        ------
+        ValueError
+            If the path is provided but does not point to a valid .npz file.
+        """
+        if self.path is not None:
+            path = Path(self.path)
+            if not path.exists():
+                raise ValueError(f"Path {path} does not exist.")
+            if path.suffix != ".npz":
+                raise ValueError(f"Path {path} must point to a .npz file.")
+            if not path.is_file():
+                raise ValueError(f"Path {path} must point to a file.")
+        return self
+
     # @model_validator(mode="after")
     # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
     #     """Validate paths provided in the config.
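Editorial note: the new `validate_path` model validator runs after field parsing, so a config pointing at a missing or non-`.npz` file now fails at construction. A usage sketch (the file path is hypothetical; pydantic surfaces the `ValueError` as a `ValidationError`):

from pydantic import ValidationError

from careamics.config.nm_model import GaussianMixtureNMConfig

try:
    GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="does_not_exist/noise_model.npz",  # hypothetical, missing file
    )
except ValidationError:
    pass  # raised because the path does not exist
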
careamics/config/optimizer_models.py
CHANGED

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Literal
+from typing import Literal, Self

 from pydantic import (
     BaseModel,
@@ -13,7 +13,6 @@ from pydantic import (
     model_validator,
 )
 from torch import optim
-from typing_extensions import Self

 from careamics.utils.torch_utils import filter_parameters

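Editorial note: here and in the other touched modules, `Self` is now imported from `typing` rather than `typing_extensions`; `typing.Self` only exists on Python 3.11+, so on older interpreters the backport is still required. A version-guarded sketch of the compatible import:

import sys

if sys.version_info >= (3, 11):
    from typing import Self  # stdlib since Python 3.11
else:
    from typing_extensions import Self  # backport for older interpreters
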
careamics/config/support/supported_algorithms.py
CHANGED

@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
     MUSPLIT = "musplit"
     """An image splitting approach based on ladder VAE architectures."""

+    MICROSPLIT = "microsplit"
+    """A micro-level image splitting approach based on ladder VAE architectures."""
+
     DENOISPLIT = "denoisplit"
     """An image splitting and denoising approach based on ladder VAE architectures."""

-
-
-    # SEG = "segmentation"
+    HDN = "hdn"
+    """Hierarchical Denoising Network, an unsupervised denoising algorithm"""
careamics/config/support/supported_losses.py
CHANGED

@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
     MAE = "mae"
     N2V = "n2v"
     # PN2V = "pn2v"
-
+    HDN = "hdn"
     MUSPLIT = "musplit"
+    MICROSPLIT = "microsplit"
     DENOISPLIT = "denoisplit"
-    DENOISPLIT_MUSPLIT = "denoisplit_musplit"
+    DENOISPLIT_MUSPLIT = (
+        "denoisplit_musplit"  # TODO refac losses, leave only microsplit
+    )
     # CE = "ce"
     # DICE = "dice"
careamics/config/training_model.py
CHANGED

@@ -3,9 +3,9 @@
 from __future__ import annotations

 from pprint import pformat
-from typing import Literal, Optional, Union
+from typing import Literal

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field

 from .callback_model import CheckpointModel, EarlyStoppingModel

@@ -29,31 +29,20 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(
         validate_assignment=True,
     )
+    lightning_trainer_config: dict | None = None
+    """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
+    Trainer class"""

-    num_epochs: int = Field(default=20, ge=1)
-    """Number of epochs, greater than 0."""
-
-    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
-    """Numerical precision"""
-    max_steps: int = Field(default=-1, ge=-1)
-    """Maximum number of steps to train for. -1 means no limit."""
-    check_val_every_n_epoch: int = Field(default=1, ge=1)
-    """Validation step frequency."""
-    accumulate_grad_batches: int = Field(default=1, ge=1)
-    """Number of batches to accumulate gradients over before stepping the optimizer."""
-    gradient_clip_val: Optional[Union[int, float]] = None
-    """The value to which to clip the gradient"""
-    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
-    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
-    logger: Optional[Literal["wandb", "tensorboard"]] = None
+    logger: Literal["wandb", "tensorboard"] | None = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""

+    # Only basic callbacks
     checkpoint_callback: CheckpointModel = CheckpointModel()
     """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
     callback."""

-    early_stopping_callback: Optional[EarlyStoppingModel] = Field(
+    early_stopping_callback: EarlyStoppingModel | None = Field(
         default=None, validate_default=True
     )
     """Early stopping callback configuration, following PyTorch Lightning Checkpoint
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
             Whether the logger is defined or not.
         """
         return self.logger is not None
-
-    @field_validator("max_steps")
-    @classmethod
-    def validate_max_steps(cls, max_steps: int) -> int:
-        """Validate the max_steps parameter.
-
-        Parameters
-        ----------
-        max_steps : int
-            Maximum number of steps to train for. -1 means no limit.
-
-        Returns
-        -------
-        int
-            Validated max_steps.
-        """
-        if max_steps == 0:
-            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
-        return max_steps
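Editorial note: the per-field Trainer options (`num_epochs`, `precision`, `max_steps`, gradient clipping, and so on) are collapsed into a single `lightning_trainer_config` dict. A hedged sketch of the intended shape; the key names follow `pytorch_lightning.Trainer`, but how careamics forwards them internally is not visible in this diff:

from careamics.config.training_model import TrainingConfig

training = TrainingConfig(
    lightning_trainer_config={
        "max_epochs": 20,  # replaces the removed num_epochs field
        "precision": "32",  # replaces the removed precision field
        "gradient_clip_val": 0.5,
        "gradient_clip_algorithm": "norm",
    },
    logger="tensorboard",
)

# Presumed downstream usage: Trainer(**training.lightning_trainer_config)
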
careamics/config/transformations/normalize_model.py
CHANGED

@@ -1,9 +1,8 @@
 """Pydantic model for the Normalize transform."""

-from typing import Literal, Optional
+from typing import Literal, Self

 from pydantic import ConfigDict, Field, model_validator
-from typing_extensions import Self

 from .transform_model import TransformModel

@@ -31,8 +30,8 @@ class NormalizeModel(TransformModel):
     name: Literal["Normalize"] = "Normalize"
     image_means: list = Field(..., min_length=0, max_length=32)
     image_stds: list = Field(..., min_length=0, max_length=32)
-    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
-    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
+    target_means: list | None = Field(default=None, min_length=0, max_length=32)
+    target_stds: list | None = Field(default=None, min_length=0, max_length=32)

     @model_validator(mode="after")
     def validate_means_stds(self: Self) -> Self:
careamics/config/transformations/xy_flip_model.py
CHANGED

@@ -1,6 +1,6 @@
 """Pydantic model for the XYFlip transform."""

-from typing import Literal, Optional
+from typing import Literal

 from pydantic import ConfigDict, Field

@@ -40,4 +40,4 @@ class XYFlipModel(TransformModel):
         ge=0,
         le=1,
     )
-    seed: Optional[int] = None
+    seed: int | None = None
careamics/config/transformations/xy_random_rotate90_model.py
CHANGED

@@ -1,6 +1,6 @@
 """Pydantic model for the XYRandomRotate90 transform."""

-from typing import Literal, Optional
+from typing import Literal

 from pydantic import ConfigDict, Field

@@ -32,4 +32,4 @@ class XYRandomRotate90Model(TransformModel):
         ge=0,
         le=1,
     )
-    seed: Optional[int] = None
+    seed: int | None = None
careamics/config/validators/validator_utils.py
CHANGED

@@ -5,7 +5,6 @@ These functions are used to validate dimensions and axes of inputs.
 """

 from collections.abc import Sequence
-from typing import Optional

 _AXES = "STCZYX"

@@ -80,7 +79,7 @@ def value_ge_than_8_power_of_2(


 def patch_size_ge_than_8_power_of_2(
-    patch_list: Optional[Sequence[int]],
+    patch_list: Sequence[int] | None,
 ) -> None:
     """
     Validate that each entry is greater or equal than 8 and a power of 2.
careamics/dataset/dataset_utils/iterate_over_files.py
CHANGED

@@ -4,7 +4,7 @@ from __future__ import annotations

 from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Optional, Union
+from typing import Union

 from numpy.typing import NDArray
 from torch.utils.data import get_worker_info
@@ -21,9 +21,9 @@ logger = get_logger(__name__)
 def iterate_over_files(
     data_config: Union[DataConfig, InferenceConfig],
     data_files: list[Path],
-    target_files: Optional[list[Path]] = None,
+    target_files: list[Path] | None = None,
     read_source_func: Callable = read_tiff,
-) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+) -> Generator[tuple[NDArray, NDArray | None], None, None]:
     """Iterate over data source and yield whole reshaped images.

     Parameters
careamics/dataset/in_memory_dataset.py
CHANGED

@@ -5,7 +5,7 @@ from __future__ import annotations

 import copy
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Union

 import numpy as np
 from torch.utils.data import Dataset
@@ -49,7 +49,7 @@ class InMemoryDataset(Dataset):
         self,
         data_config: DataConfig,
         inputs: Union[np.ndarray, list[Path]],
-        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
+        input_target: Union[np.ndarray, list[Path]] | None = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
careamics/dataset/iterable_dataset.py
CHANGED

@@ -5,7 +5,6 @@ from __future__ import annotations

 import copy
 from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Optional

 import numpy as np
 from torch.utils.data import IterableDataset
@@ -51,7 +50,7 @@ class PathIterableDataset(IterableDataset):
         self,
         data_config: DataConfig,
         src_files: list[Path],
-        target_files: Optional[list[Path]] = None,
+        target_files: list[Path] | None = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
         """Constructors.
careamics/dataset/patching/random_patching.py
CHANGED

@@ -1,7 +1,7 @@
 """Random patching utilities."""

 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Union

 import numpy as np
 import zarr
@@ -13,9 +13,9 @@ from .validate_patch_dimension import validate_patch_dimensions
 def extract_patches_random(
     arr: np.ndarray,
     patch_size: Union[list[int], tuple[int, ...]],
-    target: Optional[np.ndarray] = None,
-    seed: Optional[int] = None,
-) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    target: np.ndarray | None = None,
+    seed: int | None = None,
+) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
     """
     Generate patches from an array in a random manner.

@@ -115,8 +115,8 @@ def extract_patches_random_from_chunks(
     arr: zarr.Array,
     patch_size: Union[list[int], tuple[int, ...]],
     chunk_size: Union[list[int], tuple[int, ...]],
-    chunk_limit: Optional[int] = None,
-    seed: Optional[int] = None,
+    chunk_limit: int | None = None,
+    seed: int | None = None,
 ) -> Generator[np.ndarray, None, None]:
     """
     Generate patches from an array in a random manner.
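Editorial note: `seed: int | None = None` on the patching functions enables reproducible patch sampling while keeping non-deterministic behaviour as the default. A generic sketch of the idiom with numpy's `default_rng` (illustrative only; careamics' internal RNG handling may differ):

import numpy as np


def random_corners(shape, patch, n, seed: int | None = None):
    rng = np.random.default_rng(seed)  # seed=None -> fresh entropy each call
    ys = rng.integers(0, shape[0] - patch, size=n)
    xs = rng.integers(0, shape[1] - patch, size=n)
    return list(zip(ys, xs))


# Same seed, same patch corners.
assert random_corners((64, 64), 16, 4, seed=42) == random_corners((64, 64), 16, 4, seed=42)
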
careamics/dataset/patching/sequential_patching.py
CHANGED

@@ -1,6 +1,6 @@
 """Sequential patching functions."""

-from typing import Optional, Union
+from typing import Union

 import numpy as np
 from skimage.util import view_as_windows
@@ -110,7 +110,7 @@ def _compute_patch_views(
     window_shape: list[int],
     step: tuple[int, ...],
     output_shape: list[int],
-    target: Optional[np.ndarray] = None,
+    target: np.ndarray | None = None,
 ) -> np.ndarray:
     """
     Compute views of an array corresponding to patches.
@@ -151,8 +151,8 @@ def _compute_patch_views(
 def extract_patches_sequential(
     arr: np.ndarray,
     patch_size: Union[list[int], tuple[int, ...]],
-    target: Optional[np.ndarray] = None,
-) -> tuple[np.ndarray, Optional[np.ndarray]]:
+    target: np.ndarray | None = None,
+) -> tuple[np.ndarray, np.ndarray | None]:
     """
     Generate patches from an array in a sequential manner.

careamics/dataset/tiling/lvae_tiled_patching.py
CHANGED

@@ -3,7 +3,7 @@
 import builtins
 import itertools
 from collections.abc import Generator
-from typing import Any, Optional, Union
+from typing import Any, Union

 import numpy as np
 from numpy.typing import NDArray
@@ -16,7 +16,7 @@ def extract_tiles(
     arr: NDArray,
     tile_size: NDArray[np.int_],
     overlaps: NDArray[np.int_],
-    padding_kwargs: Optional[dict[str, Any]] = None,
+    padding_kwargs: dict[str, Any] | None = None,
 ) -> Generator[tuple[NDArray, TileInformation], None, None]:
     """Generate tiles from the input array with specified overlap.

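Editorial note: `padding_kwargs: dict[str, Any] | None = None` is the standard way to avoid a mutable `{}` default argument. A sketch of how such a parameter is typically resolved and passed to `np.pad` (the concrete defaults inside `extract_tiles` are not shown in this diff):

from typing import Any

import numpy as np


def pad_for_tiling(arr, pad: int, padding_kwargs: dict[str, Any] | None = None):
    # Resolve the None default to a fresh dict on every call.
    kwargs = {"mode": "reflect"} if padding_kwargs is None else padding_kwargs
    return np.pad(arr, pad_width=pad, **kwargs)


assert pad_for_tiling(np.zeros((32, 32)), pad=8).shape == (48, 48)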