careamics 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/RECORD +50 -35
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A module for utility functions that adapts the new dataset outputs to work with previous
|
|
3
|
+
code until it is updated.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from typing import cast
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from numpy.typing import NDArray
|
|
11
|
+
|
|
12
|
+
from careamics.config.tile_information import TileInformation
|
|
13
|
+
|
|
14
|
+
from .dataset import ImageRegionData
|
|
15
|
+
from .patching_strategies import TileSpecs
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def imageregions_to_tileinfos(
    image_regions: Sequence[ImageRegionData],
) -> list[tuple[NDArray, list[TileInformation]]]:
    """
    Convert `ImageRegionData` tiles to legacy `(data, [TileInformation])` pairs.

    Tiles are grouped by image stack (`data_idx`) and, within an image stack, by
    sample (`sample_idx`). The final tile of every sample is flagged with
    `last_tile=True`. Each returned pair holds the data of a tile together with the
    tile information derived from that same tile, and the pairs are returned in the
    grouped order, so a `last_tile` flag always closes the run of tiles of its own
    sample.

    Parameters
    ----------
    image_regions : sequence of ImageRegionData
        The tiles to convert. Each must have a `TileSpecs` dictionary as its
        `region_spec` field.

    Returns
    -------
    list of (NDArray, list of TileInformation)
        One `(data, [tile_info])` pair per input tile, ordered by `data_idx` and
        then by `sample_idx`; tiles of the same sample keep their input order.

    Raises
    ------
    KeyError
        If any `region_spec` is missing one of the `TileSpecs` keys.
    """
    converted: list[tuple[NDArray, list[TileInformation]]] = []

    # data_idx denotes which image stack a tile belongs to
    data_indices: NDArray[np.int_] = np.array(
        [image_region.region_spec["data_idx"] for image_region in image_regions],
        dtype=int,
    )

    # separate the tiles by image stack
    for data_idx in np.unique(data_indices):
        # `sorted` is stable: tiles of one sample end up adjacent and keep their
        # relative input order
        data_image_regions: list[ImageRegionData] = sorted(
            (
                image_region
                for image_region in image_regions
                if image_region.region_spec["data_idx"] == data_idx
            ),
            key=lambda image_region: image_region.region_spec["sample_idx"],
        )
        sample_indices = [
            image_region.region_spec["sample_idx"]
            for image_region in data_image_regions
        ]
        final_position = len(data_image_regions) - 1

        for i, image_region in enumerate(data_image_regions):
            # a tile is the last of its sample when the next tile (in sorted order)
            # belongs to another sample, or when there is no next tile
            last_tile = (
                i == final_position or sample_indices[i + 1] != sample_indices[i]
            )
            tile_info = _imageregion_to_tileinfo(image_region, last_tile)
            # pair the tile info with the data of the *same* region; zipping against
            # the unsorted input would mismatch them whenever the input is not
            # already ordered by (data_idx, sample_idx)
            converted.append((image_region.data, [tile_info]))

    return converted
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _imageregion_to_tileinfo(
    image_region: ImageRegionData, last_tile: bool
) -> TileInformation:
    """
    Build a `TileInformation` from one `ImageRegionData`.

    The caller decides whether the tile is the last one of its sequence and passes
    that in, since it cannot be derived from a single region.

    Parameters
    ----------
    image_region : ImageRegionData
        The region to convert; its `region_spec` field must be a `TileSpecs`
        dictionary.
    last_tile : bool
        Whether the tile is the last tile in a sequence, for stitching.

    Returns
    -------
    TileInformation
        A tile information object.

    Raises
    ------
    KeyError
        If `image_region.region_spec` does not contain the keys: {'crop_coords',
        'crop_size', 'stitch_coords'}.
    """
    region_spec = image_region.region_spec
    source_shape = image_region.data_shape

    # TODO: In python 3.11 and greater, NamedTuples can inherit from Generic
    #   so we could do image_region: ImageRegionData[TileSpecs]
    #   and not have to do this check here + cast
    # verify that the region spec carries the tile-specific keys
    tile_keys = ("crop_coords", "crop_size", "stitch_coords")
    if not all(key in region_spec for key in tile_keys):
        raise KeyError(
            "Could not find all keys: {'crop_coords', 'crop_size', 'stitch_coords'} in "
            "`image_region.region_spec`."
        )

    # the cast only informs the type checker; it does nothing at runtime
    return _tilespec_to_tileinfo(
        cast(TileSpecs, region_spec), source_shape, last_tile
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _tilespec_to_tileinfo(
    tile_spec: TileSpecs, data_shape: Sequence[int], last_tile: bool
) -> TileInformation:
    """
    Build a `TileInformation` from one `TileSpecs` dictionary.

    The start coordinates stored in the tile spec are expanded to `(start, stop)`
    pairs, using `crop_size` as the extent for both the crop and the stitch
    coordinates.

    Parameters
    ----------
    tile_spec : TileSpecs
        A tile spec dictionary.
    data_shape : sequence of int
        The original shape of the data the tile came from, labeling the dimensions of
        axes SC(Z)YX.
    last_tile : bool
        Whether a tile is the last tile in a sequence, for stitching.

    Returns
    -------
    TileInformation
        A tile information object.
    """
    # (start, stop) of the crop along each axis, relative to the tile
    crop_spans = []
    for axis in range(len(tile_spec["crop_coords"])):
        crop_start = tile_spec["crop_coords"][axis]
        crop_spans.append((crop_start, crop_start + tile_spec["crop_size"][axis]))

    # (start, stop) of where the cropped tile is stitched, relative to the image
    stitch_spans = []
    for axis in range(len(tile_spec["crop_coords"])):
        stitch_start = tile_spec["stitch_coords"][axis]
        stitch_spans.append(
            (stitch_start, stitch_start + tile_spec["crop_size"][axis])
        )

    return TileInformation(
        array_shape=tuple(data_shape[1:]),  # remove sample dimension
        last_tile=last_tile,
        overlap_crop_coords=tuple(crop_spans),
        stitch_coords=tuple(stitch_spans),
        sample_id=tile_spec["sample_idx"],
    )
|
|
@@ -1,10 +1,5 @@
|
|
|
1
|
-
__all__ = [
|
|
2
|
-
"ImageStackLoader",
|
|
3
|
-
"PatchExtractor",
|
|
4
|
-
"create_patch_extractor",
|
|
5
|
-
"get_image_stack_loader",
|
|
6
|
-
]
|
|
1
|
+
__all__ = ["GenericImageStack", "ImageStackLoader", "PatchExtractor"]
|
|
7
2
|
|
|
8
|
-
from .
|
|
3
|
+
from .image_stack import GenericImageStack
|
|
4
|
+
from .image_stack_loader import ImageStackLoader
|
|
9
5
|
from .patch_extractor import PatchExtractor
|
|
10
|
-
from .patch_extractor_factory import create_patch_extractor
|
|
@@ -11,9 +11,11 @@ from zarr.storage import FSStore
|
|
|
11
11
|
|
|
12
12
|
from careamics.config import DataConfig
|
|
13
13
|
from careamics.config.support import SupportedData
|
|
14
|
-
from careamics.dataset_ng.patch_extractor import create_patch_extractor
|
|
15
14
|
from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
|
|
16
15
|
from careamics.dataset_ng.patch_extractor.image_stack_loader import ImageStackLoader
|
|
16
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
17
|
+
create_custom_image_stack_extractor,
|
|
18
|
+
)
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
# %%
|
|
@@ -94,12 +96,12 @@ image_stack_loader: ImageStackLoader = custom_image_stack_loader
|
|
|
94
96
|
|
|
95
97
|
# %%
|
|
96
98
|
# So pylance knows that datatype is custom to match function overloads
|
|
97
|
-
assert data_config.data_type is SupportedData.CUSTOM
|
|
99
|
+
assert SupportedData(data_config.data_type) is SupportedData.CUSTOM
|
|
98
100
|
|
|
99
|
-
patch_extractor =
|
|
101
|
+
patch_extractor = create_custom_image_stack_extractor(
|
|
100
102
|
source={"store": store, "data_paths": data_paths},
|
|
101
103
|
axes=data_config.axes,
|
|
102
|
-
data_type=data_config.data_type,
|
|
104
|
+
data_type=SupportedData(data_config.data_type),
|
|
103
105
|
image_stack_loader=custom_image_stack_loader,
|
|
104
106
|
)
|
|
105
107
|
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
__all__ = [
|
|
2
|
+
"GenericImageStack",
|
|
2
3
|
"ImageStack",
|
|
3
4
|
"InMemoryImageStack",
|
|
4
5
|
"ZarrImageStack",
|
|
5
6
|
]
|
|
6
7
|
|
|
7
|
-
from .image_stack_protocol import ImageStack
|
|
8
|
+
from .image_stack_protocol import GenericImageStack, ImageStack
|
|
8
9
|
from .in_memory_image_stack import InMemoryImageStack
|
|
9
10
|
from .zarr_image_stack import ZarrImageStack
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Literal, Protocol, Union
|
|
3
|
+
from typing import Literal, Protocol, TypeVar, Union
|
|
4
4
|
|
|
5
5
|
from numpy.typing import DTypeLike, NDArray
|
|
6
6
|
|
|
@@ -51,3 +51,7 @@ class ImageStack(Protocol):
|
|
|
51
51
|
A patch of the image data from a particular sample. It will have the
|
|
52
52
|
dimensions C(Z)YX.
|
|
53
53
|
"""
|
|
54
|
+
...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
GenericImageStack = TypeVar("GenericImageStack", bound=ImageStack, covariant=True)
|
|
@@ -1,20 +1,11 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
|
-
from
|
|
3
|
-
|
|
4
|
-
Any,
|
|
5
|
-
Optional,
|
|
6
|
-
Protocol,
|
|
7
|
-
Union,
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
from numpy.typing import NDArray
|
|
2
|
+
from typing import Any, Protocol
|
|
3
|
+
|
|
11
4
|
from typing_extensions import ParamSpec
|
|
12
5
|
|
|
13
|
-
from careamics.config.support import SupportedData
|
|
14
|
-
from careamics.file_io.read import ReadFunc
|
|
15
6
|
from careamics.utils import BaseEnum
|
|
16
7
|
|
|
17
|
-
from .image_stack import
|
|
8
|
+
from .image_stack import GenericImageStack
|
|
18
9
|
|
|
19
10
|
P = ParamSpec("P")
|
|
20
11
|
|
|
@@ -23,7 +14,7 @@ class SupportedDataDev(str, BaseEnum):
|
|
|
23
14
|
ZARR = "zarr"
|
|
24
15
|
|
|
25
16
|
|
|
26
|
-
class ImageStackLoader(Protocol[P]):
|
|
17
|
+
class ImageStackLoader(Protocol[P, GenericImageStack]):
|
|
27
18
|
"""
|
|
28
19
|
Protocol to define how `ImageStacks` should be loaded.
|
|
29
20
|
|
|
@@ -76,65 +67,4 @@ class ImageStackLoader(Protocol[P]):
|
|
|
76
67
|
|
|
77
68
|
def __call__(
|
|
78
69
|
self, source: Any, axes: str, *args: P.args, **kwargs: P.kwargs
|
|
79
|
-
) -> Sequence[
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def from_arrays(
|
|
83
|
-
source: Sequence[NDArray], axes: str, *args, **kwargs
|
|
84
|
-
) -> list[InMemoryImageStack]:
|
|
85
|
-
return [InMemoryImageStack.from_array(data=array, axes=axes) for array in source]
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
# TODO: change source to directory path? Like in current implementation
|
|
89
|
-
# Advantage of having a list is the user can match input and target order themselves
|
|
90
|
-
def from_tiff_files(
|
|
91
|
-
source: Sequence[Path], axes: str, *args, **kwargs
|
|
92
|
-
) -> list[InMemoryImageStack]:
|
|
93
|
-
return [InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source]
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
# TODO: change source to directory path? Like in current implementation
|
|
97
|
-
# Advantage of having a list is the user can match input and target order themselves
|
|
98
|
-
def from_custom_file_type(
|
|
99
|
-
source: Sequence[Path],
|
|
100
|
-
axes: str,
|
|
101
|
-
read_func: ReadFunc,
|
|
102
|
-
read_kwargs: dict[str, Any],
|
|
103
|
-
*args,
|
|
104
|
-
**kwargs,
|
|
105
|
-
) -> list[InMemoryImageStack]:
|
|
106
|
-
return [
|
|
107
|
-
InMemoryImageStack.from_custom_file_type(
|
|
108
|
-
path=path,
|
|
109
|
-
axes=axes,
|
|
110
|
-
read_func=read_func,
|
|
111
|
-
**read_kwargs,
|
|
112
|
-
)
|
|
113
|
-
for path in source
|
|
114
|
-
]
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def from_ome_zarr_files(
|
|
118
|
-
source: Sequence[Path], axes: str, *args, **kwargs
|
|
119
|
-
) -> list[ZarrImageStack]:
|
|
120
|
-
# NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
|
|
121
|
-
return [ZarrImageStack.from_ome_zarr(path) for path in source]
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def get_image_stack_loader(
|
|
125
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
126
|
-
image_stack_loader: Optional[ImageStackLoader] = None,
|
|
127
|
-
) -> ImageStackLoader:
|
|
128
|
-
if data_type == SupportedData.ARRAY:
|
|
129
|
-
return from_arrays
|
|
130
|
-
elif data_type == SupportedData.TIFF:
|
|
131
|
-
return from_tiff_files
|
|
132
|
-
elif data_type == "zarr": # temp for testing until zarr is added to SupportedData
|
|
133
|
-
return from_ome_zarr_files
|
|
134
|
-
elif data_type == SupportedData.CUSTOM:
|
|
135
|
-
if image_stack_loader is None:
|
|
136
|
-
return from_custom_file_type
|
|
137
|
-
else:
|
|
138
|
-
return image_stack_loader
|
|
139
|
-
else:
|
|
140
|
-
raise ValueError
|
|
70
|
+
) -> Sequence[GenericImageStack]: ...
|
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
|
+
from typing import Generic
|
|
2
3
|
|
|
3
4
|
from numpy.typing import NDArray
|
|
4
5
|
|
|
5
|
-
from .image_stack import
|
|
6
|
+
from .image_stack import GenericImageStack
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
class PatchExtractor:
|
|
9
|
+
class PatchExtractor(Generic[GenericImageStack]):
|
|
9
10
|
"""
|
|
10
11
|
A class for extracting patches from multiple image stacks.
|
|
11
12
|
"""
|
|
12
13
|
|
|
13
|
-
def __init__(self, image_stacks: Sequence[
|
|
14
|
-
self.image_stacks: list[
|
|
14
|
+
def __init__(self, image_stacks: Sequence[GenericImageStack]):
|
|
15
|
+
self.image_stacks: list[GenericImageStack] = list(image_stacks)
|
|
15
16
|
|
|
16
17
|
def extract_patch(
|
|
17
18
|
self,
|
|
@@ -1,29 +1,29 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
from typing_extensions import ParamSpec
|
|
7
7
|
|
|
8
|
-
from careamics.config.support import SupportedData
|
|
9
8
|
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
10
9
|
from careamics.file_io.read import ReadFunc
|
|
11
10
|
|
|
11
|
+
from .image_stack import (
|
|
12
|
+
GenericImageStack,
|
|
13
|
+
InMemoryImageStack,
|
|
14
|
+
ZarrImageStack,
|
|
15
|
+
)
|
|
12
16
|
from .image_stack_loader import (
|
|
13
17
|
ImageStackLoader,
|
|
14
|
-
SupportedDataDev,
|
|
15
|
-
get_image_stack_loader,
|
|
16
18
|
)
|
|
17
19
|
|
|
18
20
|
P = ParamSpec("P")
|
|
19
21
|
|
|
20
22
|
|
|
21
|
-
# Define overloads for each implemented ImageStackLoader case
|
|
22
23
|
# Array case
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
) -> PatchExtractor:
|
|
24
|
+
def create_array_extractor(
|
|
25
|
+
source: Sequence[NDArray[Any]], axes: str
|
|
26
|
+
) -> PatchExtractor[InMemoryImageStack]:
|
|
27
27
|
"""
|
|
28
28
|
Create a patch extractor from a sequence of numpy arrays.
|
|
29
29
|
|
|
@@ -31,57 +31,78 @@ def create_patch_extractor(
|
|
|
31
31
|
----------
|
|
32
32
|
source: sequence of numpy.ndarray
|
|
33
33
|
The source arrays of the data.
|
|
34
|
-
|
|
35
|
-
The data
|
|
36
|
-
and `data_config.axes` should describe the axes of every array in the `source`.
|
|
34
|
+
axes: str
|
|
35
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
37
36
|
|
|
38
37
|
Returns
|
|
39
38
|
-------
|
|
40
39
|
PatchExtractor
|
|
41
40
|
"""
|
|
41
|
+
image_stacks = [
|
|
42
|
+
InMemoryImageStack.from_array(data=array, axes=axes) for array in source
|
|
43
|
+
]
|
|
44
|
+
return PatchExtractor(image_stacks)
|
|
45
|
+
|
|
42
46
|
|
|
47
|
+
# TIFF case
|
|
48
|
+
def create_tiff_extractor(
    source: Sequence[Path], axes: str
) -> PatchExtractor[InMemoryImageStack]:
    """
    Build a patch extractor over a sequence of TIFF files.

    Every file is loaded into an `InMemoryImageStack` and the resulting stacks are
    wrapped, in the order given, in a single `PatchExtractor`.

    Parameters
    ----------
    source: sequence of Path
        The source files for the data.
    axes: str
        The original axes of the data, must be a subset of "STCZYX".

    Returns
    -------
    PatchExtractor
    """
    tiff_stacks = []
    for tiff_path in source:
        tiff_stacks.append(InMemoryImageStack.from_tiff(path=tiff_path, axes=axes))
    return PatchExtractor(tiff_stacks)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ZARR case
|
|
72
|
+
def create_ome_zarr_extractor(
|
|
47
73
|
source: Sequence[Path],
|
|
48
74
|
axes: str,
|
|
49
|
-
|
|
50
|
-
) -> PatchExtractor:
|
|
75
|
+
) -> PatchExtractor[ZarrImageStack]:
|
|
51
76
|
"""
|
|
52
|
-
Create a patch extractor from a sequence of
|
|
77
|
+
Create a patch extractor from a sequence of OME ZARR files.
|
|
53
78
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
If the files are ZARR files they must follow the OME standard. If you have ZARR
|
|
57
|
-
files that do not follow the OME standard, see documentation on how to create
|
|
58
|
-
a custom `image_stack_loader`. (TODO: Add link).
|
|
79
|
+
If you have ZARR files that do not follow the OME standard, see documentation on
|
|
80
|
+
how to create a custom `image_stack_loader`. (TODO: Add link).
|
|
59
81
|
|
|
60
82
|
Parameters
|
|
61
83
|
----------
|
|
62
84
|
source: sequence of Path
|
|
63
85
|
The source files for the data.
|
|
64
|
-
|
|
65
|
-
The data
|
|
66
|
-
"zarr", and `data_config.axes` should describe the axes of every image in the
|
|
67
|
-
`source`.
|
|
86
|
+
axes: str
|
|
87
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
68
88
|
|
|
69
89
|
Returns
|
|
70
90
|
-------
|
|
71
91
|
PatchExtractor
|
|
72
92
|
"""
|
|
93
|
+
# NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
|
|
94
|
+
image_stacks = [ZarrImageStack.from_ome_zarr(path) for path in source]
|
|
95
|
+
return PatchExtractor(image_stacks)
|
|
73
96
|
|
|
74
97
|
|
|
75
98
|
# Custom file type case (loaded into memory)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
source: Any,
|
|
99
|
+
def create_custom_file_extractor(
|
|
100
|
+
source: Sequence[Path],
|
|
79
101
|
axes: str,
|
|
80
|
-
data_type: Literal[SupportedData.CUSTOM],
|
|
81
102
|
*,
|
|
82
103
|
read_func: ReadFunc,
|
|
83
104
|
read_kwargs: dict[str, Any],
|
|
84
|
-
) -> PatchExtractor:
|
|
105
|
+
) -> PatchExtractor[InMemoryImageStack]:
|
|
85
106
|
"""
|
|
86
107
|
Create a patch extractor from a sequence of files of a custom type.
|
|
87
108
|
|
|
@@ -89,8 +110,8 @@ def create_patch_extractor(
|
|
|
89
110
|
----------
|
|
90
111
|
source: sequence of Path
|
|
91
112
|
The source files for the data.
|
|
92
|
-
|
|
93
|
-
The data
|
|
113
|
+
axes: str
|
|
114
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
94
115
|
read_func : ReadFunc
|
|
95
116
|
A function to read the custom file type, see the `ReadFunc` protocol.
|
|
96
117
|
read_kwargs : dict of {str: Any}
|
|
@@ -100,18 +121,28 @@ def create_patch_extractor(
|
|
|
100
121
|
-------
|
|
101
122
|
PatchExtractor
|
|
102
123
|
"""
|
|
124
|
+
# TODO: lazy loading custom files
|
|
125
|
+
image_stacks = [
|
|
126
|
+
InMemoryImageStack.from_custom_file_type(
|
|
127
|
+
path=path,
|
|
128
|
+
axes=axes,
|
|
129
|
+
read_func=read_func,
|
|
130
|
+
**read_kwargs,
|
|
131
|
+
)
|
|
132
|
+
for path in source
|
|
133
|
+
]
|
|
134
|
+
|
|
135
|
+
return PatchExtractor(image_stacks)
|
|
103
136
|
|
|
104
137
|
|
|
105
138
|
# Custom ImageStackLoader case
|
|
106
|
-
|
|
107
|
-
def create_patch_extractor(
|
|
139
|
+
def create_custom_image_stack_extractor(
|
|
108
140
|
source: Any,
|
|
109
141
|
axes: str,
|
|
110
|
-
|
|
111
|
-
image_stack_loader: ImageStackLoader[P],
|
|
142
|
+
image_stack_loader: ImageStackLoader[P, GenericImageStack],
|
|
112
143
|
*args: P.args,
|
|
113
144
|
**kwargs: P.kwargs,
|
|
114
|
-
) -> PatchExtractor:
|
|
145
|
+
) -> PatchExtractor[GenericImageStack]:
|
|
115
146
|
"""
|
|
116
147
|
Create a patch extractor using a custom `ImageStackLoader`.
|
|
117
148
|
|
|
@@ -127,8 +158,8 @@ def create_patch_extractor(
|
|
|
127
158
|
----------
|
|
128
159
|
source: sequence of Path
|
|
129
160
|
The source files for the data.
|
|
130
|
-
|
|
131
|
-
The data
|
|
161
|
+
axes: str
|
|
162
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
132
163
|
image_stack_loader: ImageStackLoader
|
|
133
164
|
A custom image stack loader callable.
|
|
134
165
|
*args: Any
|
|
@@ -140,69 +171,5 @@ def create_patch_extractor(
|
|
|
140
171
|
-------
|
|
141
172
|
PatchExtractor
|
|
142
173
|
"""
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
# final overload to match the implentation function signature
|
|
146
|
-
# Need this so it works later in the code
|
|
147
|
-
# (bec there aren't created overloads for create_patch_extractors below)
|
|
148
|
-
@overload
|
|
149
|
-
def create_patch_extractor(
|
|
150
|
-
source: Any,
|
|
151
|
-
axes: str,
|
|
152
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
153
|
-
image_stack_loader: Optional[ImageStackLoader[P]] = None,
|
|
154
|
-
*args: P.args,
|
|
155
|
-
**kwargs: P.kwargs,
|
|
156
|
-
) -> PatchExtractor: ...
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def create_patch_extractor(
|
|
160
|
-
source: Any,
|
|
161
|
-
axes: str,
|
|
162
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
163
|
-
image_stack_loader: Optional[ImageStackLoader[P]] = None,
|
|
164
|
-
*args: P.args,
|
|
165
|
-
**kwargs: P.kwargs,
|
|
166
|
-
) -> PatchExtractor:
|
|
167
|
-
# TODO: Do we need to catch data_config.data_type and source mismatches?
|
|
168
|
-
# e.g. data_config.data_type is "array" but source is not Sequence[NDArray]
|
|
169
|
-
loader: ImageStackLoader[P] = get_image_stack_loader(data_type, image_stack_loader)
|
|
170
|
-
image_stacks = loader(source, axes, *args, **kwargs)
|
|
174
|
+
image_stacks = image_stack_loader(source, axes, *args, **kwargs)
|
|
171
175
|
return PatchExtractor(image_stacks)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
# TODO: Remove this and just call `create_patch_extractor` within the Dataset class
|
|
175
|
-
# Keeping for consistency for now
|
|
176
|
-
def create_patch_extractors(
|
|
177
|
-
source: Any,
|
|
178
|
-
target_source: Optional[Any],
|
|
179
|
-
axes: str,
|
|
180
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
181
|
-
image_stack_loader: Optional[ImageStackLoader] = None,
|
|
182
|
-
*args,
|
|
183
|
-
**kwargs,
|
|
184
|
-
) -> tuple[PatchExtractor, Optional[PatchExtractor]]:
|
|
185
|
-
|
|
186
|
-
# --- data extractor
|
|
187
|
-
patch_extractor: PatchExtractor = create_patch_extractor(
|
|
188
|
-
source,
|
|
189
|
-
axes,
|
|
190
|
-
data_type,
|
|
191
|
-
image_stack_loader,
|
|
192
|
-
*args,
|
|
193
|
-
**kwargs,
|
|
194
|
-
)
|
|
195
|
-
# --- optional target extractor
|
|
196
|
-
if target_source is not None:
|
|
197
|
-
target_patch_extractor = create_patch_extractor(
|
|
198
|
-
target_source,
|
|
199
|
-
axes,
|
|
200
|
-
data_type,
|
|
201
|
-
image_stack_loader,
|
|
202
|
-
*args,
|
|
203
|
-
**kwargs,
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
return patch_extractor, target_patch_extractor
|
|
207
|
-
|
|
208
|
-
return patch_extractor, None
|
|
@@ -4,8 +4,13 @@ __all__ = [
|
|
|
4
4
|
"PatchingStrategy",
|
|
5
5
|
"RandomPatchingStrategy",
|
|
6
6
|
"SequentialPatchingStrategy",
|
|
7
|
+
"TileSpecs",
|
|
8
|
+
"TilingStrategy",
|
|
9
|
+
"WholeSamplePatchingStrategy",
|
|
7
10
|
]
|
|
8
11
|
|
|
9
|
-
from .patching_strategy_protocol import PatchingStrategy, PatchSpecs
|
|
12
|
+
from .patching_strategy_protocol import PatchingStrategy, PatchSpecs, TileSpecs
|
|
10
13
|
from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
|
|
11
14
|
from .sequential_patching import SequentialPatchingStrategy
|
|
15
|
+
from .tiling_strategy import TilingStrategy
|
|
16
|
+
from .whole_sample import WholeSamplePatchingStrategy
|
|
@@ -28,6 +28,37 @@ class PatchSpecs(TypedDict):
|
|
|
28
28
|
patch_size: Sequence[int]
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
class TileSpecs(PatchSpecs):
    """A dictionary that specifies a single tile in a series of `ImageStacks`.

    Extends `PatchSpecs` with the information needed to crop a tile and stitch it
    back into its image.

    Attributes
    ----------
    data_idx: int
        Determines which `ImageStack` a patch belongs to, within a series of
        `ImageStack`s.
    sample_idx: int
        Determines which sample a patch belongs to, within an `ImageStack`.
    coords: sequence of int
        The top-left (and first z-slice for 3D data) of a patch. The sequence will have
        length 2 or 3, for 2D and 3D data respectively.
    patch_size: sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D data
        respectively.
    crop_coords: sequence of int
        The top-left side of where the tile will be cropped, in coordinates relative
        to the tile.
    crop_size: sequence of int
        The size of the cropped tile.
    stitch_coords: sequence of int
        Where the tile will be stitched back into an image, taking into account
        that the tile will be cropped, in coordinates relative to the image.
    """

    # tile-specific keys, in addition to those inherited from `PatchSpecs`
    crop_coords: Sequence[int]
    crop_size: Sequence[int]
    stitch_coords: Sequence[int]
|
|
60
|
+
|
|
61
|
+
|
|
31
62
|
class PatchingStrategy(Protocol):
|
|
32
63
|
"""
|
|
33
64
|
An interface for patching strategies.
|