careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +1 -14
- careamics/careamist.py +83 -62
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -0
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +2 -0
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +1 -79
- careamics/config/configuration_model.py +12 -7
- careamics/config/data_model.py +29 -10
- careamics/config/inference_model.py +12 -2
- careamics/config/optimizer_models.py +6 -0
- careamics/config/support/supported_data.py +29 -4
- careamics/config/tile_information.py +10 -0
- careamics/config/training_model.py +5 -1
- careamics/dataset/dataset_utils/__init__.py +0 -6
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
- careamics/dataset/in_memory_dataset.py +37 -21
- careamics/dataset/iterable_dataset.py +38 -34
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/patching.py +53 -37
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +1 -1
- careamics/prediction_utils/__init__.py +0 -2
- careamics/prediction_utils/prediction_outputs.py +18 -46
- careamics/prediction_utils/stitch_prediction.py +17 -14
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
- careamics/config/configuration_example.py +0 -86
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/prediction_utils/create_pred_datamodule.py +0 -185
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_model.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Literal, Union

 import yaml
 from bioimageio.spec.generic.v0_3 import CiteEntry
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator, model_validator
 from typing_extensions import Self

 from .algorithm_model import AlgorithmConfig
@@ -147,20 +147,25 @@ class Configuration(BaseModel):
     )

     # version
-    version: Literal["0.1.0"] = Field(
-        …
-    )
+    version: Literal["0.1.0"] = "0.1.0"
+    """CAREamics configuration version."""

     # required parameters
-    experiment_name: str = Field(
-        …
-    )
+    experiment_name: str
+    """Name of the experiment, used to name logs and checkpoints."""

     # Sub-configurations
     algorithm_config: AlgorithmConfig
+    """Algorithm configuration, holding all parameters required to configure the
+    model."""

     data_config: DataConfig
+    """Data configuration, holding all parameters required to configure the training
+    data loader."""
+
     training_config: TrainingConfig
+    """Training configuration, holding all parameters required to configure the
+    training process."""

     @field_validator("experiment_name")
     @classmethod
careamics/config/data_model.py
CHANGED
@@ -55,8 +55,8 @@ class DataConfig(BaseModel):
     ... axes="YX"
     ... )

-    To change the …
-    >>> data.…
+    To change the image_means and image_stds of the data:
+    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

     One can pass also a list of transformations, by keyword, using the
     SupportedTransform value:
@@ -80,22 +80,38 @@ class DataConfig(BaseModel):
     )

     # Dataset configuration
-    data_type: Literal["array", "tiff", "custom"] …
+    data_type: Literal["array", "tiff", "custom"]
+    """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
+    in SupportedData."""
+
+    axes: str
+    """Axes of the data, as defined in SupportedAxes."""
+
     patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
+    """Patch size, as used during training."""
+
     batch_size: int = Field(default=1, ge=1, validate_default=True)
-    axes: str
+    """Batch size for training."""

     # Optional fields
     image_means: Optional[list[float]] = Field(
         default=None, min_length=0, max_length=32
     )
+    """Means of the data across channels, used for normalization."""
+
     image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+    """Standard deviations of the data across channels, used for normalization."""
+
     target_means: Optional[list[float]] = Field(
         default=None, min_length=0, max_length=32
     )
+    """Means of the target data across channels, used for normalization."""
+
     target_stds: Optional[list[float]] = Field(
         default=None, min_length=0, max_length=32
     )
+    """Standard deviations of the target data across channels, used for
+    normalization."""

     transforms: list[TRANSFORMS_UNION] = Field(
         default=[
@@ -111,8 +127,11 @@ class DataConfig(BaseModel):
         ],
         validate_default=True,
     )
+    """List of transformations to apply to the data, available transforms are defined
+    in SupportedTransform. The default values are set for Noise2Void."""

     dataloader_params: Optional[dict] = None
+    """Dictionary of PyTorch dataloader parameters."""

     @field_validator("patch_size")
     @classmethod
@@ -346,7 +365,7 @@ class DataConfig(BaseModel):
         if self.has_n2v_manipulate():
             self.transforms.pop(-1)

-    def …
+    def set_means_and_stds(
         self,
         image_means: Union[NDArray, tuple, list, None],
         image_stds: Union[NDArray, tuple, list, None],
@@ -354,20 +373,20 @@ class DataConfig(BaseModel):
         target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
     ) -> None:
         """
-        Set mean and standard deviation of the data.
+        Set mean and standard deviation of the data across channels.

         This method should be used instead setting the fields directly, as it would
         otherwise trigger a validation error.

         Parameters
         ----------
-        image_means : …
+        image_means : numpy.ndarray, tuple or list
             Mean values for normalization.
-        image_stds : …
+        image_stds : numpy.ndarray, tuple or list
             Standard deviation values for normalization.
-        target_means : …
+        target_means : numpy.ndarray, tuple or list, optional
             Target mean values for normalization, by default ().
-        target_stds : …
+        target_stds : numpy.ndarray, tuple or list, optional
             Target standard deviation values for normalization, by default ().
         """
         # make sure we pass a list
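Putting the renamed method together with the documented fields, a DataConfig round trip looks roughly like this (a sketch reusing the doctest values from the hunk above; the patch size and data type are illustrative):

    from careamics.config import DataConfig

    data = DataConfig(
        data_type="tiff",     # "array", "tiff" or "custom"
        axes="YX",
        patch_size=[64, 64],  # 2 or 3 elements
        batch_size=1,
    )

    # statistics must go through the dedicated method; assigning the fields
    # directly would trigger a validation error, as the docstring warns
    data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])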
careamics/config/inference_model.py
CHANGED
@@ -15,25 +15,35 @@ class InferenceConfig(BaseModel):

     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

-    # Mandatory fields
     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
+    """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
+
     tile_size: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
+    """Tile size of prediction, only effective if `tile_overlap` is specified."""
+
     tile_overlap: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
+    """Overlap between tiles, only effective if `tile_size` is specified."""

     axes: str
+    """Data axes (TSCZYX) in the order of the input data."""

     image_means: list = Field(..., min_length=0, max_length=32)
+    """Mean values for each input channel."""
+
     image_stds: list = Field(..., min_length=0, max_length=32)
+    """Standard deviation values for each input channel."""

-    # only default TTAs are supported for now
+    # TODO only default TTAs are supported for now
     tta_transforms: bool = Field(default=True)
+    """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

     # Dataloader parameters
     batch_size: int = Field(default=1, ge=1)
+    """Batch size for prediction."""

     @field_validator("tile_overlap")
     @classmethod
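Taken together, the documented fields describe a prediction configuration along these lines (a sketch with made-up statistics and tile sizes; tiling only takes effect when both tile_size and tile_overlap are given, as the docstrings above note):

    from careamics.config import InferenceConfig

    pred_config = InferenceConfig(
        data_type="tiff",
        axes="YX",
        image_means=[214.3],
        image_stds=[84.5],
        tile_size=[128, 128],   # 2 or 3 elements
        tile_overlap=[48, 48],  # must accompany tile_size
        tta_transforms=True,
        batch_size=1,
    )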
careamics/config/optimizer_models.py
CHANGED
@@ -45,6 +45,7 @@ class OptimizerModel(BaseModel):

     # Mandatory field
     name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+    """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

     # Optional parameters, empty dict default value to allow filtering dictionary
     parameters: dict = Field(
@@ -53,6 +54,7 @@ class OptimizerModel(BaseModel):
         },
         validate_default=True,
     )
+    """Parameters of the optimizer, see PyTorch documentation for more details."""

     @field_validator("parameters")
     @classmethod
@@ -140,9 +142,13 @@ class LrSchedulerModel(BaseModel):

     # Mandatory field
     name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+    """Name of the learning rate scheduler, supported schedulers are defined in
+    SupportedScheduler."""

     # Optional parameters
     parameters: dict = Field(default={}, validate_default=True)
+    """Parameters of the learning rate scheduler, see PyTorch documentation for more
+    details."""

     @field_validator("parameters")
     @classmethod
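Both models forward their parameters dictionaries to the underlying PyTorch classes; a sketch with illustrative values (the module path is inferred from the file name above):

    from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

    optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-4})
    scheduler = LrSchedulerModel(
        name="ReduceLROnPlateau",
        parameters={"mode": "min", "factor": 0.5},
    )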
careamics/config/support/supported_data.py
CHANGED
@@ -60,9 +60,9 @@ class SupportedData(str, BaseEnum):
         return super()._missing_(value)

     @classmethod
-    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+    def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
         """
-        Path.rglob and fnmatch compatible extension.
+        Get Path.rglob and fnmatch compatible extension.

         Parameters
         ----------
@@ -72,13 +72,38 @@ class SupportedData(str, BaseEnum):
         Returns
         -------
         str
-            Corresponding extension.
+            Corresponding extension pattern.
         """
         if data_type == cls.ARRAY:
-            raise NotImplementedError(f"Data {data_type} …
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
         elif data_type == cls.TIFF:
             return "*.tif*"
         elif data_type == cls.CUSTOM:
             return "*.*"
         else:
             raise ValueError(f"Data type {data_type} is not supported.")
+
+    @classmethod
+    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+        """
+        Get file extension of corresponding data type.
+
+        Parameters
+        ----------
+        data_type : str or SupportedData
+            Data type.
+
+        Returns
+        -------
+        str
+            Corresponding extension.
+        """
+        if data_type == cls.ARRAY:
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
+        elif data_type == cls.TIFF:
+            return ".tiff"
+        elif data_type == cls.CUSTOM:
+            # TODO: improve this message
+            raise NotImplementedError("Custom extensions have to be passed elsewhere.")
+        else:
+            raise ValueError(f"Data type {data_type} is not supported.")
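The split separates file discovery from file writing: get_extension_pattern returns a glob/fnmatch pattern, while the new get_extension returns a concrete suffix. Reading the behavior directly off the branches above (the import path follows the file location):

    from careamics.config.support.supported_data import SupportedData

    SupportedData.get_extension_pattern("tiff")  # "*.tif*", for Path.rglob / fnmatch
    SupportedData.get_extension("tiff")          # ".tiff", e.g. to name written files
    # both raise for "array"; get_extension also raises for "custom"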
careamics/config/tile_information.py
CHANGED
@@ -19,10 +19,20 @@ class TileInformation(BaseModel):
     model_config = ConfigDict(validate_default=True)

     array_shape: tuple[int, ...]
+    """Shape of the original (untiled) array."""
+
     last_tile: bool = False
+    """Whether this tile is the last one of the array."""
+
     overlap_crop_coords: tuple[tuple[int, ...], ...]
+    """Inner coordinates of the tile where to crop the prediction in order to stitch
+    it back into the original image."""
+
     stitch_coords: tuple[tuple[int, ...], ...]
+    """Coordinates in the original image where to stitch the cropped tile back."""
+
     sample_id: int
+    """Sample ID of the tile."""

     @field_validator("array_shape")
     @classmethod
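The five fields map onto the tiled-prediction stitching flow; a constructed example with hypothetical coordinates (not checked here against the array_shape validator):

    from careamics.config.tile_information import TileInformation

    tile_info = TileInformation(
        array_shape=(256, 256),                    # original, untiled image
        last_tile=False,
        overlap_crop_coords=((8, 120), (8, 120)),  # crop inside the predicted tile
        stitch_coords=((8, 120), (8, 120)),        # destination in the full image
        sample_id=0,
    )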
careamics/config/training_model.py
CHANGED
@@ -35,15 +35,19 @@ class TrainingConfig(BaseModel):
     )

     num_epochs: int = Field(default=20, ge=1)
+    """Number of epochs, greater than 0."""

     logger: Optional[Literal["wandb", "tensorboard"]] = None
+    """Logger to use during training. If None, no logger will be used. Available
+    loggers are defined in SupportedLogger."""

     checkpoint_callback: CheckpointModel = CheckpointModel()
+    """Checkpoint callback configuration."""

     early_stopping_callback: Optional[EarlyStoppingModel] = Field(
         default=None, validate_default=True
     )
-
+    """Early stopping callback configuration."""

     def __str__(self) -> str:
         """Pretty string reprensenting the configuration.
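A corresponding construction, as a sketch (the module path is inferred from the file name above; every field shown has a default):

    from careamics.config.training_model import TrainingConfig

    training = TrainingConfig(
        num_epochs=20,
        logger="tensorboard",  # or "wandb"; None disables logging
    )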
careamics/dataset/dataset_utils/__init__.py
CHANGED
@@ -6,9 +6,6 @@ __all__ = [
     "get_files_size",
     "list_files",
     "validate_source_target_files",
-    "read_tiff",
-    "get_read_func",
-    "read_zarr",
     "iterate_over_files",
     "WelfordStatistics",
 ]
@@ -19,7 +16,4 @@ from .dataset_utils import (
 )
 from .file_utils import get_files_size, list_files, validate_source_target_files
 from .iterate_over_files import iterate_over_files
-from .read_tiff import read_tiff
-from .read_utils import get_read_func
-from .read_zarr import read_zarr
 from .running_stats import WelfordStatistics, compute_normalization_stats
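The read helpers dropped here now live in the new careamics.file_io package (see the file list at the top); the import migration, in short:

    # 0.1.0rc7
    # from careamics.dataset.dataset_utils import read_tiff, get_read_func, read_zarr

    # 0.1.0rc8 -- read_tiff is the only import shown in the hunks below;
    # get_read_func and read_zarr presumably move under careamics.file_io.read
    # (file_io/read/get_func.py and file_io/read/zarr.py in the file list)
    from careamics.file_io.read import read_tiff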
careamics/dataset/dataset_utils/file_utils.py
CHANGED
@@ -75,7 +75,7 @@ def list_files(
         raise FileNotFoundError(f"Data path {data_path} does not exist.")

     # get extension compatible with fnmatch and rglob search
-    extension = SupportedData.get_extension(data_type)
+    extension = SupportedData.get_extension_pattern(data_type)

     if data_type == SupportedData.CUSTOM and extension_filter != "":
         extension = extension_filter
careamics/dataset/dataset_utils/iterate_over_files.py
CHANGED
@@ -9,10 +9,10 @@ from numpy.typing import NDArray
 from torch.utils.data import get_worker_info

 from careamics.config import DataConfig, InferenceConfig
+from careamics.file_io.read import read_tiff
 from careamics.utils.logging import get_logger

 from .dataset_utils import reshape_array
-from .read_tiff import read_tiff

 logger = get_logger(__name__)

careamics/dataset/in_memory_dataset.py
CHANGED
@@ -9,14 +9,15 @@ from typing import Any, Callable, Optional, Union
 import numpy as np
 from torch.utils.data import Dataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import DataConfig
 from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff
 from .patching.patching import (
     PatchedOutput,
+    Stats,
     prepare_patches_supervised,
     prepare_patches_supervised_array,
     prepare_patches_unsupervised,
@@ -77,47 +78,50 @@ class InMemoryDataset(Dataset):
         # read function
         self.read_source_func = read_source_func

-        # …
+        # generate patches
         supervised = self.input_targets is not None
         patches_data = self._prepare_patches(supervised)

-        # …
+        # unpack the dataclass
         self.data = patches_data.patches
         self.data_targets = patches_data.targets

+        # set image statistics
         if self.data_config.image_means is None:
-            self.image_means = patches_data.image_stats.means
-            self.image_stds = patches_data.image_stats.stds
+            self.image_stats = patches_data.image_stats
             logger.info(
-                f"Computed dataset mean: {self.…
+                f"Computed dataset mean: {self.image_stats.means}, "
+                f"std: {self.image_stats.stds}"
             )
         else:
-            self.…
-            …
+            self.image_stats = Stats(
+                self.data_config.image_means, self.data_config.image_stds
+            )

+        # set target statistics
         if self.data_config.target_means is None:
-            self.target_means = patches_data.target_stats.means
-            self.target_stds = patches_data.target_stats.stds
+            self.target_stats = patches_data.target_stats
         else:
-            self.…
-            …
+            self.target_stats = Stats(
+                self.data_config.target_means, self.data_config.target_stds
+            )

         # update mean and std in configuration
         # the object is mutable and should then be recorded in the CAREamist obj
-        self.data_config.…
-            image_means=self.…
-            image_stds=self.…
-            target_means=self.…
-            target_stds=self.…
+        self.data_config.set_means_and_stds(
+            image_means=self.image_stats.means,
+            image_stds=self.image_stats.stds,
+            target_means=self.target_stats.means,
+            target_stds=self.target_stats.stds,
         )
         # get transforms
         self.patch_transform = Compose(
             transform_list=[
                 NormalizeModel(
-                    image_means=self.…
-                    image_stds=self.…
-                    target_means=self.…
-                    target_stds=self.…
+                    image_means=self.image_stats.means,
+                    image_stds=self.image_stats.stds,
+                    target_means=self.target_stats.means,
+                    target_stds=self.target_stats.stds,
                 )
             ]
             + self.data_config.transforms,
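Stats, newly imported from patching.patching above, is the small container these branches rely on; its observable surface in this diff is two attributes and one accessor, roughly as follows (a hypothetical reimplementation for illustration only, the real class lives in careamics/dataset/patching/patching.py):

    from dataclasses import dataclass
    from typing import Optional, Union

    from numpy.typing import NDArray

    @dataclass
    class Stats:
        means: Optional[Union[NDArray, tuple, list]]
        stds: Optional[Union[NDArray, tuple, list]]

        def get_statistics(self) -> tuple[list[float], list[float]]:
            # plain lists, as returned by the datasets' get_data_statistics()
            return list(self.means), list(self.stds)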
@@ -223,6 +227,18 @@ class InMemoryDataset(Dataset):
             "and no N2V manipulation (no N2V training)."
         )

+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        This does not return the target data statistics, only those of the input.
+
+        Returns
+        -------
+        tuple of list of floats
+            Means and standard deviations across channels of the training data.
+        """
+        return self.image_stats.get_statistics()
+
     def split_dataset(
         self,
         percentage: float = 0.1,
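Callers can therefore query a dataset directly instead of reading individual attributes; a usage fragment (dataset stands for any constructed InMemoryDataset):

    # means/stds are per-channel lists, regardless of whether they were
    # computed from the patches or taken from the configuration
    means, stds = dataset.get_data_statistics()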
careamics/dataset/iterable_dataset.py
CHANGED
@@ -12,15 +12,13 @@ from torch.utils.data import IterableDataset

 from careamics.config import DataConfig
 from careamics.config.transformations import NormalizeModel
+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..utils.logging import get_logger
-from .dataset_utils import (
-    iterate_over_files,
-    read_tiff,
-)
+from .dataset_utils import iterate_over_files
 from .dataset_utils.running_stats import WelfordStatistics
-from .patching.patching import Stats
+from .patching.patching import Stats
 from .patching.random_patching import extract_patches_random

 logger = get_logger(__name__)
@@ -78,31 +76,31 @@ class PathIterableDataset(IterableDataset):
         # only checking the image_mean because the DataConfig class ensures that
         # if image_mean is provided, image_std is also provided
         if not self.data_config.image_means:
-            self.…
+            self.image_stats, self.target_stats = self._calculate_mean_and_std()
             logger.info(
-                f"Computed dataset mean: {self.…
-                f"std: {self.…
+                f"Computed dataset mean: {self.image_stats.means},"
+                f"std: {self.image_stats.stds}"
             )

             # update the mean in the config
-            self.data_config.…
-                image_means=self.…
-                image_stds=self.…
+            self.data_config.set_means_and_stds(
+                image_means=self.image_stats.means,
+                image_stds=self.image_stats.stds,
                 target_means=(
-                    list(self.…
-                    if self.…
+                    list(self.target_stats.means)
+                    if self.target_stats.means is not None
                     else None
                 ),
                 target_stds=(
-                    list(self.…
-                    if self.…
+                    list(self.target_stats.stds)
+                    if self.target_stats.stds is not None
                     else None
                 ),
             )

         else:
             # if mean and std are provided in the config, use them
-            self.…
+            self.image_stats, self.target_stats = (
                 Stats(self.data_config.image_means, self.data_config.image_stds),
                 Stats(self.data_config.target_means, self.data_config.target_stds),
             )
@@ -111,23 +109,23 @@ class PathIterableDataset(IterableDataset):
         self.patch_transform = Compose(
             transform_list=[
                 NormalizeModel(
-                    image_means=self.…
-                    image_stds=self.…
-                    target_means=self.…
-                    target_stds=self.…
+                    image_means=self.image_stats.means,
+                    image_stds=self.image_stats.stds,
+                    target_means=self.target_stats.means,
+                    target_stds=self.target_stats.stds,
                 )
             ]
             + data_config.transforms
         )

-    def _calculate_mean_and_std(self) -> …
+    def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
         """
         Calculate mean and std of the dataset.

         Returns
         -------
-        …
-        Data…
+        tuple of Stats and optional Stats
+            Data classes containing the image and target statistics.
         """
         num_samples = 0
         image_stats = WelfordStatistics()
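WelfordStatistics computes these running statistics in a single pass over the files; for reference, the textbook Welford update it presumably builds on (a generic sketch, not CAREamics code):

    import numpy as np

    count, mean, m2 = 0, 0.0, 0.0
    for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
        count += 1
        delta = x - mean
        mean += delta / count            # running mean
        m2 += delta * (x - mean)         # running sum of squared deviations
    std = float(np.sqrt(m2 / count))     # population std; here mean=5.0, std=2.0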
@@ -155,15 +153,12 @@ class PathIterableDataset(IterableDataset):
         if target is not None:
             target_means, target_stds = target_stats.finalize()

-            …
-            …
-            …
-            …
-            …
-            …
-                np.array(target_stds) if target is not None else None,
-            ),
-        )
+            return (
+                Stats(image_means, image_stds),
+                Stats(np.array(target_means), np.array(target_stds)),
+            )
+        else:
+            return Stats(image_means, image_stds), Stats(None, None)

     def __iter__(
         self,
@@ -177,8 +172,7 @@ class PathIterableDataset(IterableDataset):
             Single patch.
         """
         assert (
-            self.data_stats.image_stats.means is not None
-            and self.data_stats.image_stats.stds is not None
+            self.image_stats.means is not None and self.image_stats.stds is not None
         ), "Mean and std must be provided"

         # iterate over files
@@ -201,6 +195,16 @@ class PathIterableDataset(IterableDataset):
                 target=patch_data[1],
             )

+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list of floats
+            Means and standard deviations across channels of the training data.
+        """
+        return self.image_stats.get_statistics()
+
     def get_number_of_files(self) -> int:
         """
         Return the number of files in the dataset.
careamics/dataset/iterable_pred_dataset.py
CHANGED
@@ -8,11 +8,12 @@ from typing import Any, Callable, Generator
 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import InferenceConfig
 from ..config.transformations import NormalizeModel
-from .dataset_utils import iterate_over_files
+from .dataset_utils import iterate_over_files


 class IterablePredDataset(IterableDataset):
careamics/dataset/iterable_tiled_pred_dataset.py
CHANGED
@@ -8,12 +8,13 @@ from typing import Any, Callable, Generator
 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import InferenceConfig
 from ..config.tile_information import TileInformation
 from ..config.transformations import NormalizeModel
-from .dataset_utils import iterate_over_files
+from .dataset_utils import iterate_over_files
 from .tiling import extract_tiles

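Every dataset module above now takes its default reader from careamics.file_io.read, while still accepting a user-supplied reader for the "custom" data type; a sketch of such a reader (the signature is inferred from the read_source_func attribute in the InMemoryDataset hunk, and read_npy is a hypothetical name):

    from pathlib import Path

    import numpy as np
    from numpy.typing import NDArray

    def read_npy(path: Path, *args, **kwargs) -> NDArray:
        # a custom reader receives the file path plus reader-specific arguments
        return np.load(path)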