careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +24 -7
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +55 -4
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +41 -4
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/optimizer_models.py +1 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/training_model.py +0 -2
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +229 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +451 -0
- careamics/dataset_ng/legacy_interoperability.py +170 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +678 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
- careamics/lightning/lightning_module.py +5 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/compose.py +1 -0
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/normalize.py +18 -7
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +25 -11
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,29 +1,30 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Literal
|
|
3
|
+
from typing import Any, Literal
|
|
4
4
|
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
from typing_extensions import ParamSpec
|
|
7
7
|
|
|
8
|
-
from careamics.config.support import SupportedData
|
|
9
8
|
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
10
9
|
from careamics.file_io.read import ReadFunc
|
|
11
10
|
|
|
11
|
+
from .image_stack import (
|
|
12
|
+
CziImageStack,
|
|
13
|
+
GenericImageStack,
|
|
14
|
+
InMemoryImageStack,
|
|
15
|
+
ZarrImageStack,
|
|
16
|
+
)
|
|
12
17
|
from .image_stack_loader import (
|
|
13
18
|
ImageStackLoader,
|
|
14
|
-
SupportedDataDev,
|
|
15
|
-
get_image_stack_loader,
|
|
16
19
|
)
|
|
17
20
|
|
|
18
21
|
P = ParamSpec("P")
|
|
19
22
|
|
|
20
23
|
|
|
21
|
-
# Define overloads for each implemented ImageStackLoader case
|
|
22
24
|
# Array case
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
) -> PatchExtractor:
|
|
25
|
+
def create_array_extractor(
|
|
26
|
+
source: Sequence[NDArray[Any]], axes: str
|
|
27
|
+
) -> PatchExtractor[InMemoryImageStack]:
|
|
27
28
|
"""
|
|
28
29
|
Create a patch extractor from a sequence of numpy arrays.
|
|
29
30
|
|
|
@@ -31,57 +32,119 @@ def create_patch_extractor(
|
|
|
31
32
|
----------
|
|
32
33
|
source: sequence of numpy.ndarray
|
|
33
34
|
The source arrays of the data.
|
|
34
|
-
|
|
35
|
-
The data
|
|
36
|
-
and `data_config.axes` should describe the axes of every array in the `source`.
|
|
35
|
+
axes: str
|
|
36
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
37
37
|
|
|
38
38
|
Returns
|
|
39
39
|
-------
|
|
40
40
|
PatchExtractor
|
|
41
41
|
"""
|
|
42
|
+
image_stacks = [
|
|
43
|
+
InMemoryImageStack.from_array(data=array, axes=axes) for array in source
|
|
44
|
+
]
|
|
45
|
+
return PatchExtractor(image_stacks)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# TIFF case
|
|
49
|
+
def create_tiff_extractor(
|
|
50
|
+
source: Sequence[Path], axes: str
|
|
51
|
+
) -> PatchExtractor[InMemoryImageStack]:
|
|
52
|
+
"""
|
|
53
|
+
Create a patch extractor from a sequence of TIFF files.
|
|
42
54
|
|
|
55
|
+
Parameters
|
|
56
|
+
----------
|
|
57
|
+
source: sequence of Path
|
|
58
|
+
The source files for the data.
|
|
59
|
+
axes: str
|
|
60
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
61
|
+
|
|
62
|
+
Returns
|
|
63
|
+
-------
|
|
64
|
+
PatchExtractor
|
|
65
|
+
"""
|
|
66
|
+
image_stacks = [
|
|
67
|
+
InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source
|
|
68
|
+
]
|
|
69
|
+
return PatchExtractor(image_stacks)
|
|
43
70
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def
|
|
71
|
+
|
|
72
|
+
# ZARR case
|
|
73
|
+
def create_ome_zarr_extractor(
|
|
47
74
|
source: Sequence[Path],
|
|
48
75
|
axes: str,
|
|
49
|
-
|
|
50
|
-
|
|
76
|
+
) -> PatchExtractor[ZarrImageStack]:
|
|
77
|
+
"""
|
|
78
|
+
Create a patch extractor from a sequence of OME ZARR files.
|
|
79
|
+
|
|
80
|
+
If you have ZARR files that do not follow the OME standard, see documentation on
|
|
81
|
+
how to create a custom `image_stack_loader`. (TODO: Add link).
|
|
82
|
+
|
|
83
|
+
Parameters
|
|
84
|
+
----------
|
|
85
|
+
source: sequence of Path
|
|
86
|
+
The source files for the data.
|
|
87
|
+
axes: str
|
|
88
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
89
|
+
|
|
90
|
+
Returns
|
|
91
|
+
-------
|
|
92
|
+
PatchExtractor
|
|
51
93
|
"""
|
|
52
|
-
|
|
94
|
+
# NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
|
|
95
|
+
image_stacks = [ZarrImageStack.from_ome_zarr(path) for path in source]
|
|
96
|
+
return PatchExtractor(image_stacks)
|
|
97
|
+
|
|
53
98
|
|
|
54
|
-
|
|
99
|
+
# CZI case
|
|
100
|
+
def create_czi_extractor(
|
|
101
|
+
source: Sequence[Path],
|
|
102
|
+
axes: str,
|
|
103
|
+
) -> PatchExtractor[CziImageStack]:
|
|
104
|
+
"""
|
|
105
|
+
Create a patch extractor from a sequence of CZI files.
|
|
55
106
|
|
|
56
|
-
If the
|
|
57
|
-
|
|
58
|
-
a custom `image_stack_loader`. (TODO: Add link).
|
|
107
|
+
If the CZI files contain multiple scenes, one patch extractor will be created for
|
|
108
|
+
each scene.
|
|
59
109
|
|
|
60
110
|
Parameters
|
|
61
111
|
----------
|
|
62
112
|
source: sequence of Path
|
|
63
113
|
The source files for the data.
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
`
|
|
114
|
+
axes: str
|
|
115
|
+
Specifies which axes of the data to use and how.
|
|
116
|
+
If this string ends with `"ZYX"` or `"TYX"`, the data will consist of 3-D
|
|
117
|
+
patches, using `Z` or `T` as third dimension, respectively.
|
|
118
|
+
If the string does not end with "ZYX", the data will consist of 2-D patches.
|
|
68
119
|
|
|
69
120
|
Returns
|
|
70
121
|
-------
|
|
71
122
|
PatchExtractor
|
|
72
123
|
"""
|
|
124
|
+
depth_axis: Literal["none", "Z", "T"] = "none"
|
|
125
|
+
if axes.endswith("TYX"):
|
|
126
|
+
depth_axis = "T"
|
|
127
|
+
elif axes.endswith("ZYX"):
|
|
128
|
+
depth_axis = "Z"
|
|
129
|
+
|
|
130
|
+
image_stacks: list[CziImageStack] = []
|
|
131
|
+
for path in source:
|
|
132
|
+
scene_rectangles = CziImageStack.get_bounding_rectangles(path)
|
|
133
|
+
image_stacks.extend(
|
|
134
|
+
CziImageStack(path, scene=scene, depth_axis=depth_axis)
|
|
135
|
+
for scene in scene_rectangles.keys()
|
|
136
|
+
)
|
|
137
|
+
return PatchExtractor(image_stacks)
|
|
73
138
|
|
|
74
139
|
|
|
75
140
|
# Custom file type case (loaded into memory)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
source: Any,
|
|
141
|
+
def create_custom_file_extractor(
|
|
142
|
+
source: Sequence[Path],
|
|
79
143
|
axes: str,
|
|
80
|
-
data_type: Literal[SupportedData.CUSTOM],
|
|
81
144
|
*,
|
|
82
145
|
read_func: ReadFunc,
|
|
83
146
|
read_kwargs: dict[str, Any],
|
|
84
|
-
) -> PatchExtractor:
|
|
147
|
+
) -> PatchExtractor[InMemoryImageStack]:
|
|
85
148
|
"""
|
|
86
149
|
Create a patch extractor from a sequence of files of a custom type.
|
|
87
150
|
|
|
@@ -89,8 +152,8 @@ def create_patch_extractor(
|
|
|
89
152
|
----------
|
|
90
153
|
source: sequence of Path
|
|
91
154
|
The source files for the data.
|
|
92
|
-
|
|
93
|
-
The data
|
|
155
|
+
axes: str
|
|
156
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
94
157
|
read_func : ReadFunc
|
|
95
158
|
A function to read the custom file type, see the `ReadFunc` protocol.
|
|
96
159
|
read_kwargs : dict of {str: Any}
|
|
@@ -100,18 +163,28 @@ def create_patch_extractor(
|
|
|
100
163
|
-------
|
|
101
164
|
PatchExtractor
|
|
102
165
|
"""
|
|
166
|
+
# TODO: lazy loading custom files
|
|
167
|
+
image_stacks = [
|
|
168
|
+
InMemoryImageStack.from_custom_file_type(
|
|
169
|
+
path=path,
|
|
170
|
+
axes=axes,
|
|
171
|
+
read_func=read_func,
|
|
172
|
+
**read_kwargs,
|
|
173
|
+
)
|
|
174
|
+
for path in source
|
|
175
|
+
]
|
|
176
|
+
|
|
177
|
+
return PatchExtractor(image_stacks)
|
|
103
178
|
|
|
104
179
|
|
|
105
180
|
# Custom ImageStackLoader case
|
|
106
|
-
|
|
107
|
-
def create_patch_extractor(
|
|
181
|
+
def create_custom_image_stack_extractor(
|
|
108
182
|
source: Any,
|
|
109
183
|
axes: str,
|
|
110
|
-
|
|
111
|
-
image_stack_loader: ImageStackLoader[P],
|
|
184
|
+
image_stack_loader: ImageStackLoader[P, GenericImageStack],
|
|
112
185
|
*args: P.args,
|
|
113
186
|
**kwargs: P.kwargs,
|
|
114
|
-
) -> PatchExtractor:
|
|
187
|
+
) -> PatchExtractor[GenericImageStack]:
|
|
115
188
|
"""
|
|
116
189
|
Create a patch extractor using a custom `ImageStackLoader`.
|
|
117
190
|
|
|
@@ -127,8 +200,8 @@ def create_patch_extractor(
|
|
|
127
200
|
----------
|
|
128
201
|
source: sequence of Path
|
|
129
202
|
The source files for the data.
|
|
130
|
-
|
|
131
|
-
The data
|
|
203
|
+
axes: str
|
|
204
|
+
The original axes of the data, must be a subset of "STCZYX".
|
|
132
205
|
image_stack_loader: ImageStackLoader
|
|
133
206
|
A custom image stack loader callable.
|
|
134
207
|
*args: Any
|
|
@@ -140,69 +213,5 @@ def create_patch_extractor(
|
|
|
140
213
|
-------
|
|
141
214
|
PatchExtractor
|
|
142
215
|
"""
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
# final overload to match the implentation function signature
|
|
146
|
-
# Need this so it works later in the code
|
|
147
|
-
# (bec there aren't created overloads for create_patch_extractors below)
|
|
148
|
-
@overload
|
|
149
|
-
def create_patch_extractor(
|
|
150
|
-
source: Any,
|
|
151
|
-
axes: str,
|
|
152
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
153
|
-
image_stack_loader: Optional[ImageStackLoader[P]] = None,
|
|
154
|
-
*args: P.args,
|
|
155
|
-
**kwargs: P.kwargs,
|
|
156
|
-
) -> PatchExtractor: ...
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def create_patch_extractor(
|
|
160
|
-
source: Any,
|
|
161
|
-
axes: str,
|
|
162
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
163
|
-
image_stack_loader: Optional[ImageStackLoader[P]] = None,
|
|
164
|
-
*args: P.args,
|
|
165
|
-
**kwargs: P.kwargs,
|
|
166
|
-
) -> PatchExtractor:
|
|
167
|
-
# TODO: Do we need to catch data_config.data_type and source mismatches?
|
|
168
|
-
# e.g. data_config.data_type is "array" but source is not Sequence[NDArray]
|
|
169
|
-
loader: ImageStackLoader[P] = get_image_stack_loader(data_type, image_stack_loader)
|
|
170
|
-
image_stacks = loader(source, axes, *args, **kwargs)
|
|
216
|
+
image_stacks = image_stack_loader(source, axes, *args, **kwargs)
|
|
171
217
|
return PatchExtractor(image_stacks)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
# TODO: Remove this and just call `create_patch_extractor` within the Dataset class
|
|
175
|
-
# Keeping for consistency for now
|
|
176
|
-
def create_patch_extractors(
|
|
177
|
-
source: Any,
|
|
178
|
-
target_source: Optional[Any],
|
|
179
|
-
axes: str,
|
|
180
|
-
data_type: Union[SupportedData, SupportedDataDev],
|
|
181
|
-
image_stack_loader: Optional[ImageStackLoader] = None,
|
|
182
|
-
*args,
|
|
183
|
-
**kwargs,
|
|
184
|
-
) -> tuple[PatchExtractor, Optional[PatchExtractor]]:
|
|
185
|
-
|
|
186
|
-
# --- data extractor
|
|
187
|
-
patch_extractor: PatchExtractor = create_patch_extractor(
|
|
188
|
-
source,
|
|
189
|
-
axes,
|
|
190
|
-
data_type,
|
|
191
|
-
image_stack_loader,
|
|
192
|
-
*args,
|
|
193
|
-
**kwargs,
|
|
194
|
-
)
|
|
195
|
-
# --- optional target extractor
|
|
196
|
-
if target_source is not None:
|
|
197
|
-
target_patch_extractor = create_patch_extractor(
|
|
198
|
-
target_source,
|
|
199
|
-
axes,
|
|
200
|
-
data_type,
|
|
201
|
-
image_stack_loader,
|
|
202
|
-
*args,
|
|
203
|
-
**kwargs,
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
return patch_extractor, target_patch_extractor
|
|
207
|
-
|
|
208
|
-
return patch_extractor, None
|
|
@@ -4,8 +4,13 @@ __all__ = [
|
|
|
4
4
|
"PatchingStrategy",
|
|
5
5
|
"RandomPatchingStrategy",
|
|
6
6
|
"SequentialPatchingStrategy",
|
|
7
|
+
"TileSpecs",
|
|
8
|
+
"TilingStrategy",
|
|
9
|
+
"WholeSamplePatchingStrategy",
|
|
7
10
|
]
|
|
8
11
|
|
|
9
|
-
from .patching_strategy_protocol import PatchingStrategy, PatchSpecs
|
|
12
|
+
from .patching_strategy_protocol import PatchingStrategy, PatchSpecs, TileSpecs
|
|
10
13
|
from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
|
|
11
14
|
from .sequential_patching import SequentialPatchingStrategy
|
|
15
|
+
from .tiling_strategy import TilingStrategy
|
|
16
|
+
from .whole_sample import WholeSamplePatchingStrategy
|
|
@@ -28,6 +28,37 @@ class PatchSpecs(TypedDict):
|
|
|
28
28
|
patch_size: Sequence[int]
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
class TileSpecs(PatchSpecs):
|
|
32
|
+
"""A dictionary that specifies a single patch in a series of `ImageStacks`.
|
|
33
|
+
|
|
34
|
+
Attributes
|
|
35
|
+
----------
|
|
36
|
+
data_idx: int
|
|
37
|
+
Determines which `ImageStack` a patch belongs to, within a series of
|
|
38
|
+
`ImageStack`s.
|
|
39
|
+
sample_idx: int
|
|
40
|
+
Determines which sample a patch belongs to, within an `ImageStack`.
|
|
41
|
+
coords: sequence of int
|
|
42
|
+
The top-left (and first z-slice for 3D data) of a patch. The sequence will have
|
|
43
|
+
length 2 or 3, for 2D and 3D data respectively.
|
|
44
|
+
patch_size: sequence of int
|
|
45
|
+
The size of the patch. The sequence will have length 2 or 3, for 2D and 3D data
|
|
46
|
+
respectively.
|
|
47
|
+
crop_coords: sequence of int
|
|
48
|
+
The top-left side of where the tile will be cropped, in coordinates relative
|
|
49
|
+
to the tile.
|
|
50
|
+
crop_size: sequence of int
|
|
51
|
+
The size of the cropped tile.
|
|
52
|
+
stitch_coords: sequence of int
|
|
53
|
+
Where the tile will be stitched back into an image, taking into account
|
|
54
|
+
that the tile will be cropped, in coords relative to the image.
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
crop_coords: Sequence[int]
|
|
58
|
+
crop_size: Sequence[int]
|
|
59
|
+
stitch_coords: Sequence[int]
|
|
60
|
+
|
|
61
|
+
|
|
31
62
|
class PatchingStrategy(Protocol):
|
|
32
63
|
"""
|
|
33
64
|
An interface for patching strategies.
|
|
@@ -335,4 +335,8 @@ def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) ->
|
|
|
335
335
|
f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
|
|
336
336
|
f"and `spatial_shape={spatial_shape}`."
|
|
337
337
|
)
|
|
338
|
-
|
|
338
|
+
patches_per_dim = [
|
|
339
|
+
np.ceil(s / p) for s, p in zip(spatial_shape, patch_size, strict=False)
|
|
340
|
+
]
|
|
341
|
+
total_patches = int(np.prod(patches_per_dim))
|
|
342
|
+
return total_patches
|
|
@@ -18,13 +18,13 @@ class SequentialPatchingStrategy:
|
|
|
18
18
|
self,
|
|
19
19
|
data_shapes: Sequence[Sequence[int]],
|
|
20
20
|
patch_size: Sequence[int],
|
|
21
|
-
|
|
21
|
+
overlaps: Optional[Sequence[int]] = None,
|
|
22
22
|
):
|
|
23
23
|
self.data_shapes = data_shapes
|
|
24
24
|
self.patch_size = patch_size
|
|
25
|
-
if
|
|
26
|
-
|
|
27
|
-
self.
|
|
25
|
+
if overlaps is None:
|
|
26
|
+
overlaps = [0] * len(patch_size)
|
|
27
|
+
self.overlaps = np.asarray(overlaps)
|
|
28
28
|
|
|
29
29
|
self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()
|
|
30
30
|
|
|
@@ -58,7 +58,7 @@ class SequentialPatchingStrategy:
|
|
|
58
58
|
data_spatial_shape = data_shape[-len(self.patch_size) :]
|
|
59
59
|
coords_list = [
|
|
60
60
|
self._compute_coords_1d(
|
|
61
|
-
self.patch_size[i], data_spatial_shape[i], self.
|
|
61
|
+
self.patch_size[i], data_spatial_shape[i], self.overlaps[i]
|
|
62
62
|
)
|
|
63
63
|
for i in range(len(self.patch_size))
|
|
64
64
|
]
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Module for the `TilingStrategy` class."""
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
|
|
6
|
+
from .patching_strategy_protocol import TileSpecs
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TilingStrategy:
|
|
10
|
+
"""
|
|
11
|
+
The tiling strategy should be used for prediction. The `get_patch_specs`
|
|
12
|
+
method returns `TileSpec` dictionaries that contains information on how to
|
|
13
|
+
stitch the tiles back together to create the full image.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
data_shapes: Sequence[Sequence[int]],
|
|
19
|
+
tile_size: Sequence[int],
|
|
20
|
+
overlaps: Sequence[int],
|
|
21
|
+
):
|
|
22
|
+
"""
|
|
23
|
+
The tiling strategy should be used for prediction. The `get_patch_specs`
|
|
24
|
+
method returns `TileSpec` dictionaries that contains information on how to
|
|
25
|
+
stitch the tiles back together to create the full image.
|
|
26
|
+
|
|
27
|
+
Parameters
|
|
28
|
+
----------
|
|
29
|
+
data_shapes : sequence of (sequence of int)
|
|
30
|
+
The shapes of the underlying data. Each element is the dimension of the
|
|
31
|
+
axes SC(Z)YX.
|
|
32
|
+
tile_size : sequence of int
|
|
33
|
+
The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
|
|
34
|
+
data respectively.
|
|
35
|
+
overlaps : sequence of int
|
|
36
|
+
How much a tile will overlap with adjacent tiles in each spatial dimension.
|
|
37
|
+
"""
|
|
38
|
+
self.data_shapes = data_shapes
|
|
39
|
+
self.tile_size = tile_size
|
|
40
|
+
self.overlaps = overlaps
|
|
41
|
+
# tile_size and overlap should have same length validated in pydantic configs
|
|
42
|
+
self.tile_specs: list[TileSpecs] = self._generate_specs()
|
|
43
|
+
|
|
44
|
+
@property
|
|
45
|
+
def n_patches(self) -> int:
|
|
46
|
+
"""
|
|
47
|
+
The number of patches that this patching strategy will return.
|
|
48
|
+
|
|
49
|
+
It also determines the maximum index that can be given to `get_patch_spec`.
|
|
50
|
+
"""
|
|
51
|
+
return len(self.tile_specs)
|
|
52
|
+
|
|
53
|
+
def get_patch_spec(self, index: int) -> TileSpecs:
|
|
54
|
+
"""Return the tile specs for a given index.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
index : int
|
|
59
|
+
A patch index.
|
|
60
|
+
|
|
61
|
+
Returns
|
|
62
|
+
-------
|
|
63
|
+
TileSpecs
|
|
64
|
+
A dictionary that specifies a single patch in a series of `ImageStacks`.
|
|
65
|
+
"""
|
|
66
|
+
return self.tile_specs[index]
|
|
67
|
+
|
|
68
|
+
def _generate_specs(self) -> list[TileSpecs]:
|
|
69
|
+
tile_specs: list[TileSpecs] = []
|
|
70
|
+
for data_idx, data_shape in enumerate(self.data_shapes):
|
|
71
|
+
spatial_shape = data_shape[2:]
|
|
72
|
+
|
|
73
|
+
# spec info for each axis
|
|
74
|
+
axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
|
|
75
|
+
self._compute_1d_coords(
|
|
76
|
+
axis_size, self.tile_size[axis_idx], self.overlaps[axis_idx]
|
|
77
|
+
)
|
|
78
|
+
for axis_idx, axis_size in enumerate(spatial_shape)
|
|
79
|
+
]
|
|
80
|
+
|
|
81
|
+
# combine by using zip
|
|
82
|
+
all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
|
|
83
|
+
*axis_specs, strict=False
|
|
84
|
+
)
|
|
85
|
+
# patches will be the same for each sample in a stack
|
|
86
|
+
for sample_idx in range(data_shape[0]):
|
|
87
|
+
# iterate through all combinations using itertools.product
|
|
88
|
+
for coords, stitch_coords, crop_coords, crop_size in zip(
|
|
89
|
+
itertools.product(*all_coords),
|
|
90
|
+
itertools.product(*all_stitch_coords),
|
|
91
|
+
itertools.product(*all_crop_coords),
|
|
92
|
+
itertools.product(*all_crop_size),
|
|
93
|
+
strict=False,
|
|
94
|
+
):
|
|
95
|
+
tile_specs.append(
|
|
96
|
+
{
|
|
97
|
+
# PatchSpecs
|
|
98
|
+
"data_idx": data_idx,
|
|
99
|
+
"sample_idx": sample_idx,
|
|
100
|
+
"coords": coords,
|
|
101
|
+
"patch_size": self.tile_size,
|
|
102
|
+
# TileSpecs additional fields
|
|
103
|
+
"crop_coords": crop_coords,
|
|
104
|
+
"crop_size": crop_size,
|
|
105
|
+
"stitch_coords": stitch_coords,
|
|
106
|
+
}
|
|
107
|
+
)
|
|
108
|
+
return tile_specs
|
|
109
|
+
|
|
110
|
+
@staticmethod
|
|
111
|
+
def _compute_1d_coords(
|
|
112
|
+
axis_size: int, tile_size: int, overlap: int
|
|
113
|
+
) -> tuple[list[int], list[int], list[int], list[int]]:
|
|
114
|
+
"""
|
|
115
|
+
Computes the TileSpec information for a single axis.
|
|
116
|
+
|
|
117
|
+
Parameters
|
|
118
|
+
----------
|
|
119
|
+
axis_size : int
|
|
120
|
+
The size of the axis.
|
|
121
|
+
tile_size : int
|
|
122
|
+
The tile size.
|
|
123
|
+
overlap : int
|
|
124
|
+
The tile overlap.
|
|
125
|
+
|
|
126
|
+
Returns
|
|
127
|
+
-------
|
|
128
|
+
coords: list of int
|
|
129
|
+
The top-left (and first z-slice for 3D data) of a tile, in coords relative
|
|
130
|
+
to the image.
|
|
131
|
+
stitch_coords: list of int
|
|
132
|
+
Where the tile will be stitched back into an image, taking into account
|
|
133
|
+
that the tile will be cropped, in coords relative to the image.
|
|
134
|
+
crop_coords: list of int
|
|
135
|
+
The top-left side of where the tile will be cropped, in coordinates relative
|
|
136
|
+
to the tile.
|
|
137
|
+
crop_size: list of int
|
|
138
|
+
The size of the cropped tile.
|
|
139
|
+
"""
|
|
140
|
+
coords: list[int] = []
|
|
141
|
+
stitch_coords: list[int] = []
|
|
142
|
+
crop_coords: list[int] = []
|
|
143
|
+
crop_size: list[int] = []
|
|
144
|
+
|
|
145
|
+
step = tile_size - overlap
|
|
146
|
+
for i in range(0, max(1, axis_size - overlap), step):
|
|
147
|
+
if i == 0:
|
|
148
|
+
coords.append(i)
|
|
149
|
+
crop_coords.append(0)
|
|
150
|
+
stitch_coords.append(0)
|
|
151
|
+
crop_size.append(tile_size - overlap // 2)
|
|
152
|
+
elif (i > 0) and (i + tile_size < axis_size):
|
|
153
|
+
coords.append(i)
|
|
154
|
+
crop_coords.append(overlap // 2)
|
|
155
|
+
stitch_coords.append(coords[-1] + crop_coords[-1])
|
|
156
|
+
crop_size.append(tile_size - overlap)
|
|
157
|
+
else:
|
|
158
|
+
previous_crop_size = crop_size[-1] if crop_size else 1
|
|
159
|
+
previous_stitch_coord = stitch_coords[-1] if stitch_coords else 0
|
|
160
|
+
previous_tile_end = previous_stitch_coord + previous_crop_size
|
|
161
|
+
|
|
162
|
+
coords.append(max(0, axis_size - tile_size))
|
|
163
|
+
stitch_coords.append(previous_tile_end)
|
|
164
|
+
crop_coords.append(stitch_coords[-1] - coords[-1])
|
|
165
|
+
crop_size.append(axis_size - stitch_coords[-1])
|
|
166
|
+
|
|
167
|
+
return (
|
|
168
|
+
coords,
|
|
169
|
+
stitch_coords,
|
|
170
|
+
crop_coords,
|
|
171
|
+
crop_size,
|
|
172
|
+
)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class WholeSamplePatchingStrategy:
|
|
7
|
+
# TODO: warn this strategy should only be used with batch size = 1
|
|
8
|
+
# for the case of multiple image stacks with different dimensions
|
|
9
|
+
|
|
10
|
+
# TODO: docs
|
|
11
|
+
def __init__(self, data_shapes: Sequence[Sequence[int]]):
|
|
12
|
+
self.data_shapes = data_shapes
|
|
13
|
+
|
|
14
|
+
self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()
|
|
15
|
+
|
|
16
|
+
@property
|
|
17
|
+
def n_patches(self) -> int:
|
|
18
|
+
return len(self.patch_specs)
|
|
19
|
+
|
|
20
|
+
def get_patch_spec(self, index: int) -> PatchSpecs:
|
|
21
|
+
return self.patch_specs[index]
|
|
22
|
+
|
|
23
|
+
def _initialize_patch_specs(self) -> list[PatchSpecs]:
|
|
24
|
+
patch_specs: list[PatchSpecs] = []
|
|
25
|
+
for data_idx, data_shape in enumerate(self.data_shapes):
|
|
26
|
+
spatial_shape = data_shape[2:]
|
|
27
|
+
for sample_idx in range(data_shape[0]):
|
|
28
|
+
patch_specs.append(
|
|
29
|
+
{
|
|
30
|
+
"data_idx": data_idx,
|
|
31
|
+
"sample_idx": sample_idx,
|
|
32
|
+
"coords": tuple(0 for _ in spatial_shape),
|
|
33
|
+
"patch_size": spatial_shape,
|
|
34
|
+
}
|
|
35
|
+
)
|
|
36
|
+
return patch_specs
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Next-Generation DataModules for Careamics."""
|