careamics 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/__init__.py +0 -4
- careamics/careamist.py +0 -1
- careamics/config/__init__.py +1 -13
- careamics/config/algorithms/care_algorithm_model.py +84 -0
- careamics/config/algorithms/n2n_algorithm_model.py +85 -0
- careamics/config/algorithms/n2v_algorithm_model.py +269 -1
- careamics/config/configuration.py +21 -13
- careamics/config/configuration_factories.py +179 -187
- careamics/config/configuration_io.py +2 -2
- careamics/config/data/__init__.py +1 -4
- careamics/config/data/data_model.py +46 -62
- careamics/config/support/supported_transforms.py +1 -1
- careamics/config/transformations/__init__.py +0 -2
- careamics/config/transformations/n2v_manipulate_model.py +15 -0
- careamics/config/transformations/transform_unions.py +0 -13
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +3 -10
- careamics/dataset/in_memory_pred_dataset.py +3 -5
- careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +3 -5
- careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
- careamics/dataset_ng/dataset/__init__.py +3 -0
- careamics/dataset_ng/dataset/dataset.py +184 -0
- careamics/dataset_ng/demo_dataset.ipynb +271 -0
- careamics/dataset_ng/demo_patch_extractor.py +53 -0
- careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
- careamics/dataset_ng/patch_extractor/__init__.py +10 -0
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
- careamics/dataset_ng/patching_strategies/__init__.py +11 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
- careamics/lightning/lightning_module.py +78 -27
- careamics/lightning/train_data_module.py +8 -39
- careamics/losses/fcn/losses.py +17 -10
- careamics/model_io/bioimage/bioimage_utils.py +5 -3
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +2 -2
- careamics/transforms/__init__.py +2 -1
- careamics/transforms/compose.py +5 -15
- careamics/transforms/n2v_manipulate_torch.py +143 -0
- careamics/transforms/pixel_manipulation.py +1 -0
- careamics/transforms/pixel_manipulation_torch.py +418 -0
- careamics/utils/version.py +38 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/RECORD +58 -41
- careamics/config/care_configuration.py +0 -100
- careamics/config/data/n2v_data_model.py +0 -193
- careamics/config/n2n_configuration.py +0 -101
- careamics/config/n2v_configuration.py +0 -266
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0

careamics/config/transformations/__init__.py

@@ -1,7 +1,6 @@
 """CAREamics transformation Pydantic models."""
 
 __all__ = [
-    "N2V_TRANSFORMS_UNION",
     "NORM_AND_SPATIAL_UNION",
     "SPATIAL_TRANSFORMS_UNION",
     "N2VManipulateModel",
@@ -16,7 +15,6 @@ from .n2v_manipulate_model import N2VManipulateModel
 from .normalize_model import NormalizeModel
 from .transform_model import TransformModel
 from .transform_unions import (
-    N2V_TRANSFORMS_UNION,
     NORM_AND_SPATIAL_UNION,
     SPATIAL_TRANSFORMS_UNION,
 )

careamics/config/transformations/n2v_manipulate_model.py

@@ -7,6 +7,8 @@ from pydantic import ConfigDict, Field, field_validator
 from .transform_model import TransformModel
 
 
+# TODO should probably not be a TransformModel anymore, no reason for it
+# `name` is used as a discriminator field in the transforms
 class N2VManipulateModel(TransformModel):
     """
     Pydantic model used to represent N2V manipulation.
@@ -32,11 +34,24 @@ class N2VManipulateModel(TransformModel):
     )
 
     name: Literal["N2VManipulate"] = "N2VManipulate"
+
     roi_size: int = Field(default=11, ge=3, le=21)
+    """Size of the region where the pixel manipulation is applied."""
+
     masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
+    """Percentage of masked pixels per image."""
+
+    remove_center: bool = Field(default=True)  # TODO remove it
+    """Exclude center pixel from average calculation."""  # TODO rephrase this
+
     strategy: Literal["uniform", "median"] = Field(default="uniform")
+    """Strategy for pixel value replacement."""
+
     struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
+    """Orientation of the structN2V mask. Set to `"none"` to not apply StructN2V."""
+
     struct_mask_span: int = Field(default=5, ge=3, le=15)
+    """Size of the structN2V mask."""
 
     @field_validator("roi_size", "struct_mask_span")
     @classmethod
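
The fields above carry the standard Noise2Void masking defaults. A minimal usage sketch, assuming only what the diff shows: the model is exported from `careamics.config.transformations` (per the `__all__` entry above) and the extra checks performed by the `field_validator` are not part of this diff.

from careamics.config.transformations import N2VManipulateModel

# Defaults declared above: uniform replacement, 11-pixel ROI,
# 0.2% masked pixels, no structN2V mask.
default_cfg = N2VManipulateModel()
assert default_cfg.roi_size == 11 and default_cfg.strategy == "uniform"

# Pydantic enforces the Field bounds, e.g. roi_size must lie in [3, 21];
# any additional constraint from the field_validator is applied on top.
struct_cfg = N2VManipulateModel(
    roi_size=15,
    masked_pixel_percentage=0.5,
    strategy="median",
    struct_mask_axis="horizontal",
    struct_mask_span=7,
)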

careamics/config/transformations/transform_unions.py

@@ -4,7 +4,6 @@ from typing import Annotated, Union
 
 from pydantic import Discriminator
 
-from .n2v_manipulate_model import N2VManipulateModel
 from .normalize_model import NormalizeModel
 from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model
@@ -14,7 +13,6 @@ NORM_AND_SPATIAL_UNION = Annotated[
         NormalizeModel,
         XYFlipModel,
         XYRandomRotate90Model,
-        N2VManipulateModel,
     ],
     Discriminator("name"),  # used to tell the different transform models apart
 ]
@@ -29,14 +27,3 @@ SPATIAL_TRANSFORMS_UNION = Annotated[
     Discriminator("name"),  # used to tell the different transform models apart
 ]
 """Available spatial transforms in CAREamics."""
-
-
-N2V_TRANSFORMS_UNION = Annotated[
-    Union[
-        XYFlipModel,
-        XYRandomRotate90Model,
-        N2VManipulateModel,
-    ],
-    Discriminator("name"),  # used to tell the different transform models apart
-]
-"""Available N2V-compatible transforms in CAREamics."""
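
The remaining unions dispatch on the `name` field via pydantic's `Discriminator`. A small sketch of what that enables, assuming `XYFlipModel` declares `name: Literal["XYFlip"]` (the discriminator values of the member models are not shown in this diff):

from pydantic import TypeAdapter

from careamics.config.transformations import SPATIAL_TRANSFORMS_UNION

adapter = TypeAdapter(SPATIAL_TRANSFORMS_UNION)

# The value of "name" selects which member model validates the payload,
# so an unknown name fails fast instead of being coerced into the wrong model.
transform = adapter.validate_python({"name": "XYFlip"})
print(type(transform).__name__)  # -> XYFlipModel, under the assumption above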

careamics/dataset/dataset_utils/iterate_over_files.py

@@ -9,7 +9,7 @@ from typing import Callable, Optional, Union
 from numpy.typing import NDArray
 from torch.utils.data import get_worker_info
 
-from careamics.config import
+from careamics.config import DataConfig, InferenceConfig
 from careamics.file_io.read import read_tiff
 from careamics.utils.logging import get_logger
 
@@ -19,7 +19,7 @@ logger = get_logger(__name__)
 
 
 def iterate_over_files(
-    data_config: Union[
+    data_config: Union[DataConfig, InferenceConfig],
     data_files: list[Path],
     target_files: Optional[list[Path]] = None,
    read_source_func: Callable = read_tiff,

careamics/dataset/in_memory_dataset.py

@@ -9,7 +9,7 @@ from typing import Any, Callable, Optional, Union
 import numpy as np
 from torch.utils.data import Dataset
 
-from careamics.config import
+from careamics.config import DataConfig
 from careamics.config.transformations import NormalizeModel
 from careamics.dataset.patching.patching import (
     PatchedOutput,
@@ -46,7 +46,7 @@ class InMemoryDataset(Dataset):
 
     def __init__(
         self,
-        data_config:
+        data_config: DataConfig,
         inputs: Union[np.ndarray, list[Path]],
         input_target: Optional[Union[np.ndarray, list[Path]]] = None,
         read_source_func: Callable = read_tiff,
@@ -215,16 +215,9 @@ class InMemoryDataset(Dataset):
         if self.data_targets is not None:
             # get target
             target = self.data_targets[index]
-
             return self.patch_transform(patch=patch, target=target)
 
-
-            return self.patch_transform(patch=patch)
-        else:
-            raise ValueError(
-                "Something went wrong! No target provided (not supervised training) "
-                "while the algorithm is not Noise2Void."
-            )
+        return self.patch_transform(patch=patch)
 
     def get_data_statistics(self) -> tuple[list[float], list[float]]:
         """Return training data statistics.
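
With this change the dataset no longer raises when no target is provided: it simply returns the transformed patch, and includes the target only when targets were supplied. A hedged sketch of the consumer side, where `dataset` stands for an already constructed `InMemoryDataset` and the default torch collate is assumed:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, shuffle=True)

for batch in loader:
    # patch_transform returns a tuple of arrays, so batches unpack as
    # (x, y) when targets exist and (x,) for unsupervised training.
    if len(batch) == 2:
        x, y = batch
    else:
        (x,) = batch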

careamics/dataset/in_memory_pred_dataset.py

@@ -69,7 +69,7 @@ class InMemoryPredDataset(Dataset):
         """
         return len(self.data)
 
-    def __getitem__(self, index: int) -> NDArray:
+    def __getitem__(self, index: int) -> tuple[NDArray, ...]:
         """
         Return the patch corresponding to the provided index.
 
@@ -80,9 +80,7 @@ class InMemoryPredDataset(Dataset):
 
         Returns
         -------
-
+        tuple(numpy.ndarray, ...)
             Transformed patch.
         """
-
-
-        return transformed_patch
+        return self.patch_transform(patch=self.data[index])

careamics/dataset/in_memory_tiled_pred_dataset.py

@@ -107,7 +107,7 @@ class InMemoryTiledPredDataset(Dataset):
         """
         return len(self.data)
 
-    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
+    def __getitem__(self, index: int) -> tuple[tuple[NDArray, ...], TileInformation]:
         """
         Return the patch corresponding to the provided index.
 
@@ -124,6 +124,6 @@ class InMemoryTiledPredDataset(Dataset):
         tile_array, tile_info = self.data[index]
 
         # Apply transforms
-        transformed_tile
+        transformed_tile = self.patch_transform(patch=tile_array)
 
         return transformed_tile, tile_info

careamics/dataset/iterable_dataset.py

@@ -10,7 +10,7 @@ from typing import Callable, Optional
 import numpy as np
 from torch.utils.data import IterableDataset
 
-from careamics.config import
+from careamics.config import DataConfig
 from careamics.config.transformations import NormalizeModel
 from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose
@@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset):
 
     def __init__(
         self,
-        data_config:
+        data_config: DataConfig,
         src_files: list[Path],
         target_files: Optional[list[Path]] = None,
         read_source_func: Callable = read_tiff,

careamics/dataset/iterable_pred_dataset.py

@@ -97,13 +97,13 @@ class IterablePredDataset(IterableDataset):
 
     def __iter__(
         self,
-    ) -> Generator[NDArray, None, None]:
+    ) -> Generator[tuple[NDArray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
         Yields
         ------
-
+        (numpy.ndarray, numpy.ndarray or None)
             Single patch.
         """
         assert (
@@ -118,6 +118,4 @@ class IterablePredDataset(IterableDataset):
             # sample has S dimension
             for i in range(sample.shape[0]):
 
-
-
-                yield transformed_sample
+                yield self.patch_transform(patch=sample[i])

careamics/dataset/iterable_tiled_pred_dataset.py

@@ -109,13 +109,13 @@ class IterableTiledPredDataset(IterableDataset):
 
     def __iter__(
         self,
-    ) -> Generator[tuple[NDArray, TileInformation], None, None]:
+    ) -> Generator[tuple[tuple[NDArray, ...], TileInformation], None, None]:
         """
         Iterate over data source and yield single patch.
 
         Yields
         ------
-        Generator of
+        Generator of (np.ndarray, np.ndarray or None) and TileInformation tuple
             Generator of single tiles.
         """
         assert (
@@ -136,6 +136,6 @@ class IterableTiledPredDataset(IterableDataset):
 
             # apply transform to patches
             for patch_array, tile_info in patch_gen:
-                transformed_patch
+                transformed_patch = self.patch_transform(patch=patch_array)
 
                 yield transformed_patch, tile_info

careamics/dataset_ng/dataset/dataset.py

@@ -0,0 +1,184 @@
+from collections.abc import Sequence
+from enum import Enum
+from pathlib import Path
+from typing import Literal, NamedTuple, Optional, Union
+
+import numpy as np
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+from typing_extensions import ParamSpec
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.config.support import SupportedData
+from careamics.dataset.patching.patching import Stats
+from careamics.dataset_ng.patch_extractor import (
+    ImageStackLoader,
+    PatchExtractor,
+    create_patch_extractor,
+)
+from careamics.dataset_ng.patching_strategies import (
+    FixedRandomPatchingStrategy,
+    PatchingStrategy,
+    PatchSpecs,
+    RandomPatchingStrategy,
+)
+from careamics.transforms import Compose
+
+P = ParamSpec("P")
+
+
+class Mode(str, Enum):
+    TRAINING = "training"
+    VALIDATING = "validating"
+    PREDICTING = "predicting"
+
+
+class ImageRegionData(NamedTuple):
+    data: NDArray
+    source: Union[Path, Literal["array"]]
+    data_shape: Sequence[int]
+    dtype: str  # dtype should be str for collate
+    axes: str
+    region_spec: PatchSpecs
+
+
+InputType = Union[Sequence[np.ndarray], Sequence[Path]]
+
+
+class CareamicsDataset(Dataset):
+    def __init__(
+        self,
+        data_config: Union[DataConfig, InferenceConfig],
+        mode: Mode,
+        inputs: InputType,
+        targets: Optional[InputType] = None,
+        image_stack_loader: Optional[ImageStackLoader[P]] = None,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ):
+        self.config = data_config
+        self.mode = mode
+
+        data_type_enum = SupportedData(self.config.data_type)
+        self.input_extractor = create_patch_extractor(
+            inputs,
+            self.config.axes,
+            data_type_enum,
+            image_stack_loader,
+            *args,
+            **kwargs,
+        )
+        if targets is not None:
+            self.target_extractor: Optional[PatchExtractor] = create_patch_extractor(
+                targets,
+                self.config.axes,
+                data_type_enum,
+                image_stack_loader,
+                *args,
+                **kwargs,
+            )
+        else:
+            self.target_extractor = None
+
+        self.patching_strategy = self._initialize_patching_strategy()
+
+        self.input_stats, self.target_stats = self._initialize_statistics()
+
+        self.transforms = self._initialize_transforms()
+
+    def _initialize_patching_strategy(self) -> PatchingStrategy:
+        patching_strategy: PatchingStrategy
+        if self.mode == Mode.TRAINING:
+            if isinstance(self.config, InferenceConfig):
+                raise ValueError("Inference config cannot be used for training.")
+            patching_strategy = RandomPatchingStrategy(
+                data_shapes=self.input_extractor.shape,
+                patch_size=self.config.patch_size,
+                # TODO: Add random seed to dataconfig
+                seed=getattr(self.config, "random_seed", None),
+            )
+        elif self.mode == Mode.VALIDATING:
+            if isinstance(self.config, InferenceConfig):
+                raise ValueError("Inference config cannot be used for validating.")
+            patching_strategy = FixedRandomPatchingStrategy(
+                data_shapes=self.input_extractor.shape,
+                patch_size=self.config.patch_size,
+                # TODO: Add random seed to dataconfig
+                seed=getattr(self.config, "random_seed", None),
+            )
+        elif self.mode == Mode.PREDICTING:
+            # TODO: patching strategy will be tilingStrategy in upcoming PR
+            raise NotImplementedError(
+                "Prediction mode for the CAREamicsDataset has not been implemented yet."
+            )
+        else:
+            raise ValueError(f"Unrecognised dataset mode {self.mode}.")
+
+        return patching_strategy
+
+    def _initialize_transforms(self) -> Optional[Compose]:
+        if isinstance(self.config, DataConfig):
+            return Compose(
+                transform_list=list(self.config.transforms),
+            )
+        # TODO: add TTA
+        return None
+
+    def _initialize_statistics(self) -> tuple[Stats, Optional[Stats]]:
+        # TODO: add running stats
+        # Currently assume that stats are provided in the configuration
+        input_stats = Stats(self.config.image_means, self.config.image_stds)
+        target_stats = None
+        if isinstance(self.config, DataConfig):
+            target_means = self.config.target_means
+            target_stds = self.config.target_stds
+            if target_means is not None and target_stds is not None:
+                target_stats = Stats(target_means, target_stds)
+        return input_stats, target_stats
+
+    def __len__(self):
+        return self.patching_strategy.n_patches
+
+    def _create_image_region(
+        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
+    ) -> ImageRegionData:
+        data_idx = patch_spec["data_idx"]
+        return ImageRegionData(
+            data=patch,
+            source=extractor.image_stacks[data_idx].source,
+            dtype=str(extractor.image_stacks[data_idx].data_dtype),
+            data_shape=extractor.image_stacks[data_idx].data_shape,
+            # TODO: should it be axes of the original image instead?
+            axes=self.config.axes,
+            region_spec=patch_spec,
+        )
+
+    def __getitem__(
+        self, index: int
+    ) -> tuple[ImageRegionData, Optional[ImageRegionData]]:
+        patch_spec = self.patching_strategy.get_patch_spec(index)
+        input_patch = self.input_extractor.extract_patch(**patch_spec)
+
+        target_patch = (
+            self.target_extractor.extract_patch(**patch_spec)
+            if self.target_extractor is not None
+            else None
+        )
+
+        if self.transforms is not None:
+            input_patch, target_patch = self.transforms(input_patch, target_patch)
+
+        input_data = self._create_image_region(
+            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
+        )
+
+        if target_patch is not None and self.target_extractor is not None:
+            target_data = self._create_image_region(
+                patch=target_patch,
+                patch_spec=patch_spec,
+                extractor=self.target_extractor,
+            )
+        else:
+            target_data = None
+
+        return input_data, target_data
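
The new `CareamicsDataset` ties a patch extractor, a patching strategy and the configured transforms together and returns `ImageRegionData` named tuples. A rough usage sketch based only on the constructor above; the exact set of `DataConfig` fields required here is an assumption, and the class is imported from the module path shown in the file listing rather than a re-export:

import numpy as np

from careamics.config import DataConfig
from careamics.dataset_ng.dataset.dataset import CareamicsDataset, Mode

# Hypothetical configuration values; statistics are passed explicitly because
# _initialize_statistics currently expects them in the configuration.
config = DataConfig(
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=1,
    image_means=[0.0],
    image_stds=[1.0],
)

image = np.random.rand(256, 256).astype(np.float32)
dataset = CareamicsDataset(data_config=config, mode=Mode.TRAINING, inputs=[image])

input_region, target_region = dataset[0]  # target_region is None without targets
print(input_region.data.shape, input_region.region_spec)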