careamics 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +49 -49
- careamics/cli/conf.py +6 -6
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/algorithms/vae_algorithm_model.py +4 -4
- careamics/config/callback_model.py +8 -8
- careamics/config/configuration_factories.py +49 -49
- careamics/config/data/data_model.py +7 -13
- careamics/config/data/ng_data_model.py +8 -14
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -10
- careamics/config/likelihood_model.py +2 -2
- careamics/config/nm_model.py +5 -7
- careamics/config/training_model.py +4 -4
- careamics/config/transformations/normalize_model.py +3 -3
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +12 -14
- careamics/lightning/predict_data_module.py +8 -8
- careamics/lightning/train_data_module.py +11 -11
- careamics/losses/lvae/losses.py +9 -9
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +4 -4
- careamics/models/layers.py +5 -5
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +3 -5
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/RECORD +56 -56
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
6
|
from pprint import pformat
|
|
7
|
-
from typing import Annotated, Any, Literal,
|
|
7
|
+
from typing import Annotated, Any, Literal, Union
|
|
8
8
|
from warnings import warn
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
@@ -106,22 +106,16 @@ class NGDataConfig(BaseModel):
|
|
|
106
106
|
batch_size: int = Field(default=1, ge=1, validate_default=True)
|
|
107
107
|
"""Batch size for training."""
|
|
108
108
|
|
|
109
|
-
image_means:
|
|
110
|
-
default=None, min_length=0, max_length=32
|
|
111
|
-
)
|
|
109
|
+
image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
112
110
|
"""Means of the data across channels, used for normalization."""
|
|
113
111
|
|
|
114
|
-
image_stds:
|
|
112
|
+
image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
115
113
|
"""Standard deviations of the data across channels, used for normalization."""
|
|
116
114
|
|
|
117
|
-
target_means:
|
|
118
|
-
default=None, min_length=0, max_length=32
|
|
119
|
-
)
|
|
115
|
+
target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
120
116
|
"""Means of the target data across channels, used for normalization."""
|
|
121
117
|
|
|
122
|
-
target_stds:
|
|
123
|
-
default=None, min_length=0, max_length=32
|
|
124
|
-
)
|
|
118
|
+
target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
|
|
125
119
|
"""Standard deviations of the target data across channels, used for
|
|
126
120
|
normalization."""
|
|
127
121
|
|
|
@@ -148,7 +142,7 @@ class NGDataConfig(BaseModel):
|
|
|
148
142
|
test_dataloader_params: dict[str, Any] = Field(default={})
|
|
149
143
|
"""Dictionary of PyTorch test dataloader parameters."""
|
|
150
144
|
|
|
151
|
-
seed:
|
|
145
|
+
seed: int | None = Field(default=None, gt=0)
|
|
152
146
|
"""Random seed for reproducibility."""
|
|
153
147
|
|
|
154
148
|
@field_validator("axes")
|
|
@@ -330,8 +324,8 @@ class NGDataConfig(BaseModel):
|
|
|
330
324
|
self,
|
|
331
325
|
image_means: Union[NDArray, tuple, list, None],
|
|
332
326
|
image_stds: Union[NDArray, tuple, list, None],
|
|
333
|
-
target_means:
|
|
334
|
-
target_stds:
|
|
327
|
+
target_means: Union[NDArray, tuple, list, None] | None = None,
|
|
328
|
+
target_stds: Union[NDArray, tuple, list, None] | None = None,
|
|
335
329
|
) -> None:
|
|
336
330
|
"""
|
|
337
331
|
Set mean and standard deviation of the data across channels.
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Sequential patching Pydantic model."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
|
-
from typing import Optional
|
|
5
4
|
|
|
6
5
|
from pydantic import Field, ValidationInfo, field_validator
|
|
7
6
|
|
|
@@ -24,7 +23,7 @@ class _OverlappingPatchedModel(_PatchedModel):
|
|
|
24
23
|
dimension, and the number of dimensions be either 2 or 3.
|
|
25
24
|
"""
|
|
26
25
|
|
|
27
|
-
overlaps:
|
|
26
|
+
overlaps: Sequence[int] | None = Field(
|
|
28
27
|
default=None,
|
|
29
28
|
min_length=2,
|
|
30
29
|
max_length=3,
|
|
@@ -37,8 +36,8 @@ class _OverlappingPatchedModel(_PatchedModel):
|
|
|
37
36
|
@field_validator("overlaps")
|
|
38
37
|
@classmethod
|
|
39
38
|
def overlap_smaller_than_patch_size(
|
|
40
|
-
cls, overlaps:
|
|
41
|
-
) ->
|
|
39
|
+
cls, overlaps: Sequence[int] | None, values: ValidationInfo
|
|
40
|
+
) -> Sequence[int] | None:
|
|
42
41
|
"""
|
|
43
42
|
Validate overlap.
|
|
44
43
|
|
|
@@ -78,7 +77,7 @@ class _OverlappingPatchedModel(_PatchedModel):
|
|
|
78
77
|
|
|
79
78
|
@field_validator("overlaps")
|
|
80
79
|
@classmethod
|
|
81
|
-
def overlap_even(cls, overlaps:
|
|
80
|
+
def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
|
|
82
81
|
"""
|
|
83
82
|
Validate overlaps.
|
|
84
83
|
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import Any, Literal,
|
|
5
|
+
from typing import Any, Literal, Union
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
8
|
from typing_extensions import Self
|
|
@@ -18,12 +18,10 @@ class InferenceConfig(BaseModel):
|
|
|
18
18
|
data_type: Literal["array", "tiff", "czi", "custom"] # As defined in SupportedData
|
|
19
19
|
"""Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""
|
|
20
20
|
|
|
21
|
-
tile_size:
|
|
22
|
-
default=None, min_length=2, max_length=3
|
|
23
|
-
)
|
|
21
|
+
tile_size: Union[list[int]] | None = Field(default=None, min_length=2, max_length=3)
|
|
24
22
|
"""Tile size of prediction, only effective if `tile_overlap` is specified."""
|
|
25
23
|
|
|
26
|
-
tile_overlap:
|
|
24
|
+
tile_overlap: Union[list[int]] | None = Field(
|
|
27
25
|
default=None, min_length=2, max_length=3
|
|
28
26
|
)
|
|
29
27
|
"""Overlap between tiles, only effective if `tile_size` is specified."""
|
|
@@ -48,8 +46,8 @@ class InferenceConfig(BaseModel):
|
|
|
48
46
|
@field_validator("tile_overlap")
|
|
49
47
|
@classmethod
|
|
50
48
|
def all_elements_non_zero_even(
|
|
51
|
-
cls, tile_overlap:
|
|
52
|
-
) ->
|
|
49
|
+
cls, tile_overlap: list[int] | None
|
|
50
|
+
) -> list[int] | None:
|
|
53
51
|
"""
|
|
54
52
|
Validate tile overlap.
|
|
55
53
|
|
|
@@ -86,9 +84,7 @@ class InferenceConfig(BaseModel):
|
|
|
86
84
|
|
|
87
85
|
@field_validator("tile_size")
|
|
88
86
|
@classmethod
|
|
89
|
-
def tile_min_8_power_of_2(
|
|
90
|
-
cls, tile_list: Optional[list[int]]
|
|
91
|
-
) -> Optional[list[int]]:
|
|
87
|
+
def tile_min_8_power_of_2(cls, tile_list: list[int] | None) -> list[int] | None:
|
|
92
88
|
"""
|
|
93
89
|
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
94
90
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Likelihood model."""
|
|
2
2
|
|
|
3
|
-
from typing import Annotated, Literal,
|
|
3
|
+
from typing import Annotated, Literal, Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
@@ -31,7 +31,7 @@ class GaussianLikelihoodConfig(BaseModel):
|
|
|
31
31
|
|
|
32
32
|
model_config = ConfigDict(validate_assignment=True)
|
|
33
33
|
|
|
34
|
-
predict_logvar:
|
|
34
|
+
predict_logvar: Literal["pixelwise"] | None = None
|
|
35
35
|
"""If `pixelwise`, log-variance is computed for each pixel, else log-variance
|
|
36
36
|
is not computed."""
|
|
37
37
|
|
careamics/config/nm_model.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Noise models config."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Annotated, Literal,
|
|
4
|
+
from typing import Annotated, Literal, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import torch
|
|
@@ -42,21 +42,19 @@ class GaussianMixtureNMConfig(BaseModel):
|
|
|
42
42
|
# model type
|
|
43
43
|
model_type: Literal["GaussianMixtureNoiseModel"]
|
|
44
44
|
|
|
45
|
-
path:
|
|
45
|
+
path: Union[Path, str] | None = None
|
|
46
46
|
"""Path to the directory where the trained noise model (*.npz) is saved in the
|
|
47
47
|
`train` method."""
|
|
48
48
|
|
|
49
49
|
# TODO remove and use as parameters to the NM functions?
|
|
50
|
-
signal:
|
|
50
|
+
signal: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
|
|
51
51
|
"""Path to the file containing signal or respective numpy array."""
|
|
52
52
|
|
|
53
53
|
# TODO remove and use as parameters to the NM functions?
|
|
54
|
-
observation:
|
|
55
|
-
default=None, exclude=True
|
|
56
|
-
)
|
|
54
|
+
observation: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
|
|
57
55
|
"""Path to the file containing observation or respective numpy array."""
|
|
58
56
|
|
|
59
|
-
weight:
|
|
57
|
+
weight: Array | None = None
|
|
60
58
|
"""A [3*n_gaussian, n_coeff] sized array containing the values of the weights
|
|
61
59
|
describing the GMM noise model, with each row corresponding to one
|
|
62
60
|
parameter of each gaussian, namely [mean, standard deviation and weight].
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal,
|
|
6
|
+
from typing import Literal, Union
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
9
9
|
|
|
@@ -41,11 +41,11 @@ class TrainingConfig(BaseModel):
|
|
|
41
41
|
"""Validation step frequency."""
|
|
42
42
|
accumulate_grad_batches: int = Field(default=1, ge=1)
|
|
43
43
|
"""Number of batches to accumulate gradients over before stepping the optimizer."""
|
|
44
|
-
gradient_clip_val:
|
|
44
|
+
gradient_clip_val: Union[int, float] | None = None
|
|
45
45
|
"""The value to which to clip the gradient"""
|
|
46
46
|
gradient_clip_algorithm: Literal["value", "norm"] = "norm"
|
|
47
47
|
"""The algorithm to use for gradient clipping (see lightning `Trainer`)."""
|
|
48
|
-
logger:
|
|
48
|
+
logger: Literal["wandb", "tensorboard"] | None = None
|
|
49
49
|
"""Logger to use during training. If None, no logger will be used. Available
|
|
50
50
|
loggers are defined in SupportedLogger."""
|
|
51
51
|
|
|
@@ -53,7 +53,7 @@ class TrainingConfig(BaseModel):
|
|
|
53
53
|
"""Checkpoint callback configuration, following PyTorch Lightning Checkpoint
|
|
54
54
|
callback."""
|
|
55
55
|
|
|
56
|
-
early_stopping_callback:
|
|
56
|
+
early_stopping_callback: EarlyStoppingModel | None = Field(
|
|
57
57
|
default=None, validate_default=True
|
|
58
58
|
)
|
|
59
59
|
"""Early stopping callback configuration, following PyTorch Lightning Checkpoint
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Pydantic model for the Normalize transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal
|
|
4
4
|
|
|
5
5
|
from pydantic import ConfigDict, Field, model_validator
|
|
6
6
|
from typing_extensions import Self
|
|
@@ -31,8 +31,8 @@ class NormalizeModel(TransformModel):
|
|
|
31
31
|
name: Literal["Normalize"] = "Normalize"
|
|
32
32
|
image_means: list = Field(..., min_length=0, max_length=32)
|
|
33
33
|
image_stds: list = Field(..., min_length=0, max_length=32)
|
|
34
|
-
target_means:
|
|
35
|
-
target_stds:
|
|
34
|
+
target_means: list | None = Field(default=None, min_length=0, max_length=32)
|
|
35
|
+
target_stds: list | None = Field(default=None, min_length=0, max_length=32)
|
|
36
36
|
|
|
37
37
|
@model_validator(mode="after")
|
|
38
38
|
def validate_means_stds(self: Self) -> Self:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Pydantic model for the XYFlip transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal
|
|
4
4
|
|
|
5
5
|
from pydantic import ConfigDict, Field
|
|
6
6
|
|
|
@@ -40,4 +40,4 @@ class XYFlipModel(TransformModel):
|
|
|
40
40
|
ge=0,
|
|
41
41
|
le=1,
|
|
42
42
|
)
|
|
43
|
-
seed:
|
|
43
|
+
seed: int | None = None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal
|
|
4
4
|
|
|
5
5
|
from pydantic import ConfigDict, Field
|
|
6
6
|
|
|
@@ -32,4 +32,4 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
32
32
|
ge=0,
|
|
33
33
|
le=1,
|
|
34
34
|
)
|
|
35
|
-
seed:
|
|
35
|
+
seed: int | None = None
|
|
@@ -5,7 +5,6 @@ These functions are used to validate dimensions and axes of inputs.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from collections.abc import Sequence
|
|
8
|
-
from typing import Optional
|
|
9
8
|
|
|
10
9
|
_AXES = "STCZYX"
|
|
11
10
|
|
|
@@ -80,7 +79,7 @@ def value_ge_than_8_power_of_2(
|
|
|
80
79
|
|
|
81
80
|
|
|
82
81
|
def patch_size_ge_than_8_power_of_2(
|
|
83
|
-
patch_list:
|
|
82
|
+
patch_list: Sequence[int] | None,
|
|
84
83
|
) -> None:
|
|
85
84
|
"""
|
|
86
85
|
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Callable, Generator
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import Union
|
|
8
8
|
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
10
|
from torch.utils.data import get_worker_info
|
|
@@ -21,9 +21,9 @@ logger = get_logger(__name__)
|
|
|
21
21
|
def iterate_over_files(
|
|
22
22
|
data_config: Union[DataConfig, InferenceConfig],
|
|
23
23
|
data_files: list[Path],
|
|
24
|
-
target_files:
|
|
24
|
+
target_files: list[Path] | None = None,
|
|
25
25
|
read_source_func: Callable = read_tiff,
|
|
26
|
-
) -> Generator[tuple[NDArray,
|
|
26
|
+
) -> Generator[tuple[NDArray, NDArray | None], None, None]:
|
|
27
27
|
"""Iterate over data source and yield whole reshaped images.
|
|
28
28
|
|
|
29
29
|
Parameters
|
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
import copy
|
|
6
6
|
from collections.abc import Callable
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Any,
|
|
8
|
+
from typing import Any, Union
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
from torch.utils.data import Dataset
|
|
@@ -49,7 +49,7 @@ class InMemoryDataset(Dataset):
|
|
|
49
49
|
self,
|
|
50
50
|
data_config: DataConfig,
|
|
51
51
|
inputs: Union[np.ndarray, list[Path]],
|
|
52
|
-
input_target:
|
|
52
|
+
input_target: Union[np.ndarray, list[Path]] | None = None,
|
|
53
53
|
read_source_func: Callable = read_tiff,
|
|
54
54
|
**kwargs: Any,
|
|
55
55
|
) -> None:
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
import copy
|
|
6
6
|
from collections.abc import Callable, Generator
|
|
7
7
|
from pathlib import Path
|
|
8
|
-
from typing import Optional
|
|
9
8
|
|
|
10
9
|
import numpy as np
|
|
11
10
|
from torch.utils.data import IterableDataset
|
|
@@ -51,7 +50,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
51
50
|
self,
|
|
52
51
|
data_config: DataConfig,
|
|
53
52
|
src_files: list[Path],
|
|
54
|
-
target_files:
|
|
53
|
+
target_files: list[Path] | None = None,
|
|
55
54
|
read_source_func: Callable = read_tiff,
|
|
56
55
|
) -> None:
|
|
57
56
|
"""Constructors.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Random patching utilities."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Generator
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import zarr
|
|
@@ -13,9 +13,9 @@ from .validate_patch_dimension import validate_patch_dimensions
|
|
|
13
13
|
def extract_patches_random(
|
|
14
14
|
arr: np.ndarray,
|
|
15
15
|
patch_size: Union[list[int], tuple[int, ...]],
|
|
16
|
-
target:
|
|
17
|
-
seed:
|
|
18
|
-
) -> Generator[tuple[np.ndarray,
|
|
16
|
+
target: np.ndarray | None = None,
|
|
17
|
+
seed: int | None = None,
|
|
18
|
+
) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
|
|
19
19
|
"""
|
|
20
20
|
Generate patches from an array in a random manner.
|
|
21
21
|
|
|
@@ -115,8 +115,8 @@ def extract_patches_random_from_chunks(
|
|
|
115
115
|
arr: zarr.Array,
|
|
116
116
|
patch_size: Union[list[int], tuple[int, ...]],
|
|
117
117
|
chunk_size: Union[list[int], tuple[int, ...]],
|
|
118
|
-
chunk_limit:
|
|
119
|
-
seed:
|
|
118
|
+
chunk_limit: int | None = None,
|
|
119
|
+
seed: int | None = None,
|
|
120
120
|
) -> Generator[np.ndarray, None, None]:
|
|
121
121
|
"""
|
|
122
122
|
Generate patches from an array in a random manner.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Sequential patching functions."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Union
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from skimage.util import view_as_windows
|
|
@@ -110,7 +110,7 @@ def _compute_patch_views(
|
|
|
110
110
|
window_shape: list[int],
|
|
111
111
|
step: tuple[int, ...],
|
|
112
112
|
output_shape: list[int],
|
|
113
|
-
target:
|
|
113
|
+
target: np.ndarray | None = None,
|
|
114
114
|
) -> np.ndarray:
|
|
115
115
|
"""
|
|
116
116
|
Compute views of an array corresponding to patches.
|
|
@@ -151,8 +151,8 @@ def _compute_patch_views(
|
|
|
151
151
|
def extract_patches_sequential(
|
|
152
152
|
arr: np.ndarray,
|
|
153
153
|
patch_size: Union[list[int], tuple[int, ...]],
|
|
154
|
-
target:
|
|
155
|
-
) -> tuple[np.ndarray,
|
|
154
|
+
target: np.ndarray | None = None,
|
|
155
|
+
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
156
156
|
"""
|
|
157
157
|
Generate patches from an array in a sequential manner.
|
|
158
158
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import builtins
|
|
4
4
|
import itertools
|
|
5
5
|
from collections.abc import Generator
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from numpy.typing import NDArray
|
|
@@ -16,7 +16,7 @@ def extract_tiles(
|
|
|
16
16
|
arr: NDArray,
|
|
17
17
|
tile_size: NDArray[np.int_],
|
|
18
18
|
overlaps: NDArray[np.int_],
|
|
19
|
-
padding_kwargs:
|
|
19
|
+
padding_kwargs: dict[str, Any] | None = None,
|
|
20
20
|
) -> Generator[tuple[NDArray, TileInformation], None, None]:
|
|
21
21
|
"""Generate tiles from the input array with specified overlap.
|
|
22
22
|
|
careamics/dataset_ng/dataset.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Generic, Literal, NamedTuple,
|
|
4
|
+
from typing import Any, Generic, Literal, NamedTuple, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from numpy.typing import NDArray
|
|
@@ -51,7 +51,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
51
51
|
data_config: NGDataConfig,
|
|
52
52
|
mode: Mode,
|
|
53
53
|
input_extractor: PatchExtractor[GenericImageStack],
|
|
54
|
-
target_extractor:
|
|
54
|
+
target_extractor: PatchExtractor[GenericImageStack] | None = None,
|
|
55
55
|
):
|
|
56
56
|
self.config = data_config
|
|
57
57
|
self.mode = mode
|
|
@@ -115,7 +115,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
|
|
|
115
115
|
|
|
116
116
|
return patching_strategy
|
|
117
117
|
|
|
118
|
-
def _initialize_transforms(self) ->
|
|
118
|
+
def _initialize_transforms(self) -> Compose | None:
|
|
119
119
|
normalize = NormalizeModel(
|
|
120
120
|
image_means=self.input_stats.means,
|
|
121
121
|
image_stds=self.input_stats.stds,
|
careamics/dataset_ng/factory.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
5
5
|
|
|
6
6
|
from numpy.typing import NDArray
|
|
7
7
|
from typing_extensions import ParamSpec
|
|
@@ -48,8 +48,8 @@ class DatasetType(Enum):
|
|
|
48
48
|
def determine_dataset_type(
|
|
49
49
|
data_type: SupportedData,
|
|
50
50
|
in_memory: bool,
|
|
51
|
-
read_func:
|
|
52
|
-
image_stack_loader:
|
|
51
|
+
read_func: ReadFunc | None = None,
|
|
52
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
53
53
|
) -> DatasetType:
|
|
54
54
|
"""Determine what the dataset type should be based on the input arguments.
|
|
55
55
|
|
|
@@ -121,10 +121,10 @@ def create_dataset(
|
|
|
121
121
|
inputs: Any,
|
|
122
122
|
targets: Any,
|
|
123
123
|
in_memory: bool,
|
|
124
|
-
read_func:
|
|
125
|
-
read_kwargs:
|
|
126
|
-
image_stack_loader:
|
|
127
|
-
image_stack_loader_kwargs:
|
|
124
|
+
read_func: ReadFunc | None = None,
|
|
125
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
126
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
127
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
128
128
|
) -> CareamicsDataset[ImageStack]:
|
|
129
129
|
"""
|
|
130
130
|
Convenience function to create the CAREamicsDataset.
|
|
@@ -201,7 +201,7 @@ def create_array_dataset(
|
|
|
201
201
|
config: NGDataConfig,
|
|
202
202
|
mode: Mode,
|
|
203
203
|
inputs: Sequence[NDArray[Any]],
|
|
204
|
-
targets:
|
|
204
|
+
targets: Sequence[NDArray[Any]] | None,
|
|
205
205
|
) -> CareamicsDataset[InMemoryImageStack]:
|
|
206
206
|
"""
|
|
207
207
|
Create a CAREamicsDataset from array data.
|
|
@@ -223,7 +223,7 @@ def create_array_dataset(
|
|
|
223
223
|
A CAREamicsDataset.
|
|
224
224
|
"""
|
|
225
225
|
input_extractor = create_array_extractor(source=inputs, axes=config.axes)
|
|
226
|
-
target_extractor:
|
|
226
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
227
227
|
if targets is not None:
|
|
228
228
|
target_extractor = create_array_extractor(source=targets, axes=config.axes)
|
|
229
229
|
else:
|
|
@@ -235,7 +235,7 @@ def create_tiff_dataset(
|
|
|
235
235
|
config: NGDataConfig,
|
|
236
236
|
mode: Mode,
|
|
237
237
|
inputs: Sequence[Path],
|
|
238
|
-
targets:
|
|
238
|
+
targets: Sequence[Path] | None,
|
|
239
239
|
) -> CareamicsDataset[InMemoryImageStack]:
|
|
240
240
|
"""
|
|
241
241
|
Create a CAREamicsDataset from tiff files that will be all loaded into memory.
|
|
@@ -260,7 +260,7 @@ def create_tiff_dataset(
|
|
|
260
260
|
source=inputs,
|
|
261
261
|
axes=config.axes,
|
|
262
262
|
)
|
|
263
|
-
target_extractor:
|
|
263
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
264
264
|
if targets is not None:
|
|
265
265
|
target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
|
|
266
266
|
else:
|
|
@@ -273,7 +273,7 @@ def create_czi_dataset(
|
|
|
273
273
|
config: NGDataConfig,
|
|
274
274
|
mode: Mode,
|
|
275
275
|
inputs: Sequence[Path],
|
|
276
|
-
targets:
|
|
276
|
+
targets: Sequence[Path] | None,
|
|
277
277
|
) -> CareamicsDataset[CziImageStack]:
|
|
278
278
|
"""
|
|
279
279
|
Create a dataset from CZI files.
|
|
@@ -296,7 +296,7 @@ def create_czi_dataset(
|
|
|
296
296
|
"""
|
|
297
297
|
|
|
298
298
|
input_extractor = create_czi_extractor(source=inputs, axes=config.axes)
|
|
299
|
-
target_extractor:
|
|
299
|
+
target_extractor: PatchExtractor[CziImageStack] | None
|
|
300
300
|
if targets is not None:
|
|
301
301
|
target_extractor = create_czi_extractor(source=targets, axes=config.axes)
|
|
302
302
|
else:
|
|
@@ -309,7 +309,7 @@ def create_ome_zarr_dataset(
|
|
|
309
309
|
config: NGDataConfig,
|
|
310
310
|
mode: Mode,
|
|
311
311
|
inputs: Sequence[Path],
|
|
312
|
-
targets:
|
|
312
|
+
targets: Sequence[Path] | None,
|
|
313
313
|
) -> CareamicsDataset[ZarrImageStack]:
|
|
314
314
|
"""
|
|
315
315
|
Create a dataset from OME ZARR files.
|
|
@@ -332,7 +332,7 @@ def create_ome_zarr_dataset(
|
|
|
332
332
|
"""
|
|
333
333
|
|
|
334
334
|
input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
|
|
335
|
-
target_extractor:
|
|
335
|
+
target_extractor: PatchExtractor[ZarrImageStack] | None
|
|
336
336
|
if targets is not None:
|
|
337
337
|
target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
|
|
338
338
|
else:
|
|
@@ -345,7 +345,7 @@ def create_custom_file_dataset(
|
|
|
345
345
|
config: NGDataConfig,
|
|
346
346
|
mode: Mode,
|
|
347
347
|
inputs: Sequence[Path],
|
|
348
|
-
targets:
|
|
348
|
+
targets: Sequence[Path] | None,
|
|
349
349
|
*,
|
|
350
350
|
read_func: ReadFunc,
|
|
351
351
|
read_kwargs: dict[str, Any],
|
|
@@ -378,7 +378,7 @@ def create_custom_file_dataset(
|
|
|
378
378
|
input_extractor = create_custom_file_extractor(
|
|
379
379
|
source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
|
|
380
380
|
)
|
|
381
|
-
target_extractor:
|
|
381
|
+
target_extractor: PatchExtractor[InMemoryImageStack] | None
|
|
382
382
|
if targets is not None:
|
|
383
383
|
target_extractor = create_custom_file_extractor(
|
|
384
384
|
source=targets,
|
|
@@ -396,7 +396,7 @@ def create_custom_image_stack_dataset(
|
|
|
396
396
|
config: NGDataConfig,
|
|
397
397
|
mode: Mode,
|
|
398
398
|
inputs: Any,
|
|
399
|
-
targets:
|
|
399
|
+
targets: Any | None,
|
|
400
400
|
image_stack_loader: ImageStackLoader[P, GenericImageStack],
|
|
401
401
|
*args: P.args,
|
|
402
402
|
**kwargs: P.kwargs,
|
|
@@ -436,7 +436,7 @@ def create_custom_image_stack_dataset(
|
|
|
436
436
|
*args,
|
|
437
437
|
**kwargs,
|
|
438
438
|
)
|
|
439
|
-
target_extractor:
|
|
439
|
+
target_extractor: PatchExtractor[GenericImageStack] | None
|
|
440
440
|
if targets is not None:
|
|
441
441
|
target_extractor = create_custom_image_stack_extractor(
|
|
442
442
|
targets,
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""A module for random patching strategies."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
|
-
from typing import Optional
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
@@ -31,7 +30,7 @@ class RandomPatchingStrategy:
|
|
|
31
30
|
self,
|
|
32
31
|
data_shapes: Sequence[Sequence[int]],
|
|
33
32
|
patch_size: Sequence[int],
|
|
34
|
-
seed:
|
|
33
|
+
seed: int | None = None,
|
|
35
34
|
):
|
|
36
35
|
"""
|
|
37
36
|
A patching strategy for sampling random patches.
|
|
@@ -193,7 +192,7 @@ class FixedRandomPatchingStrategy:
|
|
|
193
192
|
self,
|
|
194
193
|
data_shapes: Sequence[Sequence[int]],
|
|
195
194
|
patch_size: Sequence[int],
|
|
196
|
-
seed:
|
|
195
|
+
seed: int | None = None,
|
|
197
196
|
):
|
|
198
197
|
"""A patching strategy for sampling random patches.
|
|
199
198
|
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import itertools
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from typing import Optional
|
|
4
3
|
|
|
5
4
|
import numpy as np
|
|
6
5
|
from typing_extensions import ParamSpec
|
|
@@ -18,7 +17,7 @@ class SequentialPatchingStrategy:
|
|
|
18
17
|
self,
|
|
19
18
|
data_shapes: Sequence[Sequence[int]],
|
|
20
19
|
patch_size: Sequence[int],
|
|
21
|
-
overlaps:
|
|
20
|
+
overlaps: Sequence[int] | None = None,
|
|
22
21
|
):
|
|
23
22
|
self.data_shapes = data_shapes
|
|
24
23
|
self.patch_size = patch_size
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Sequence
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Union
|
|
8
8
|
|
|
9
9
|
from pytorch_lightning import LightningModule, Trainer
|
|
10
10
|
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
@@ -84,9 +84,9 @@ class PredictionWriterCallback(BasePredictionWriter):
|
|
|
84
84
|
cls,
|
|
85
85
|
write_type: SupportedWriteType,
|
|
86
86
|
tiled: bool,
|
|
87
|
-
write_func:
|
|
88
|
-
write_extension:
|
|
89
|
-
write_func_kwargs:
|
|
87
|
+
write_func: WriteFunc | None = None,
|
|
88
|
+
write_extension: str | None = None,
|
|
89
|
+
write_func_kwargs: dict[str, Any] | None = None,
|
|
90
90
|
dirpath: Union[Path, str] = "predictions",
|
|
91
91
|
) -> PredictionWriterCallback: # TODO: change type hint to self (find out how)
|
|
92
92
|
"""
|
|
@@ -172,7 +172,7 @@ class PredictionWriterCallback(BasePredictionWriter):
|
|
|
172
172
|
trainer: Trainer,
|
|
173
173
|
pl_module: LightningModule,
|
|
174
174
|
prediction: Any, # TODO: change to expected type
|
|
175
|
-
batch_indices:
|
|
175
|
+
batch_indices: Sequence[int] | None,
|
|
176
176
|
batch: Any, # TODO: change to expected type
|
|
177
177
|
batch_idx: int,
|
|
178
178
|
dataloader_idx: int,
|