careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset/dataset_utils/running_stats.py +7 -3
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -335,4 +335,6 @@ def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) ->
|
|
|
335
335
|
f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
|
|
336
336
|
f"and `spatial_shape={spatial_shape}`."
|
|
337
337
|
)
|
|
338
|
-
|
|
338
|
+
patches_per_dim = [np.ceil(s / p) for s, p in zip(spatial_shape, patch_size)]
|
|
339
|
+
total_patches = int(np.prod(patches_per_dim))
|
|
340
|
+
return total_patches
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Module for the `TilingStrategy` class."""
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
|
|
6
|
+
from .patching_strategy_protocol import TileSpecs
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TilingStrategy:
    """
    Patching strategy that covers every sample with overlapping tiles.

    The tiling strategy should be used for prediction. The `get_patch_spec`
    method returns `TileSpecs` dictionaries that contain information on how to
    stitch the tiles back together to create the full image.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        tile_size: Sequence[int],
        overlaps: Sequence[int],
    ):
        """
        The tiling strategy should be used for prediction. The `get_patch_spec`
        method returns `TileSpecs` dictionaries that contain information on how to
        stitch the tiles back together to create the full image.

        All tile specs are computed eagerly when the strategy is constructed.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        tile_size : sequence of int
            The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        overlaps : sequence of int
            How much a tile will overlap with adjacent tiles in each spatial dimension.

        Raises
        ------
        ValueError
            If `tile_size` or `overlaps` does not have exactly one entry per
            spatial dimension of every data shape, or if an overlap is negative or
            not strictly smaller than the corresponding tile size.
        """
        self.data_shapes = data_shapes
        self.tile_size = tile_size
        self.overlaps = overlaps
        # lengths are expected to be validated by the pydantic configs already,
        # but `_generate_specs` re-checks them so direct construction fails loudly
        self.tile_specs: list[TileSpecs] = self._generate_specs()

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        Valid indices for `get_patch_spec` are `0` to `n_patches - 1`.
        """
        return len(self.tile_specs)

    def get_patch_spec(self, index: int) -> TileSpecs:
        """Return the tile specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        TileSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.
        """
        return self.tile_specs[index]

    def _generate_specs(self) -> list[TileSpecs]:
        """
        Compute the `TileSpecs` of every tile of every sample of every data shape.

        Specs are ordered by data index, then sample index, then by
        `itertools.product` over the per-axis tile positions (so the last spatial
        axis varies fastest).

        Returns
        -------
        list of TileSpecs
            One entry per tile.

        Raises
        ------
        ValueError
            If `tile_size` or `overlaps` does not match the number of spatial
            dimensions of a data shape, or if an overlap is invalid for its tile
            size.
        """
        tile_specs: list[TileSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            # axes are SC(Z)YX, so the spatial dimensions start at index 2
            spatial_shape = data_shape[2:]
            if not (len(spatial_shape) == len(self.tile_size) == len(self.overlaps)):
                raise ValueError(
                    "`tile_size` and `overlaps` must have one entry per spatial "
                    f"dimension of the data, but data shape {data_shape} (axes "
                    f"SC(Z)YX) has {len(spatial_shape)} spatial dimension(s); got "
                    f"`tile_size={self.tile_size}` and `overlaps={self.overlaps}`."
                )

            # spec info for each axis
            axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
                self._compute_1d_coords(
                    axis_size, self.tile_size[axis_idx], self.overlaps[axis_idx]
                )
                for axis_idx, axis_size in enumerate(spatial_shape)
            ]

            # regroup from per-axis tuples into per-field sequences of axes
            all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
                *axis_specs
            )
            # patches will be the same for each sample in a stack
            for sample_idx in range(data_shape[0]):
                # the four products iterate in the same order, so zipping them
                # pairs up the fields belonging to the same N-D tile
                for coords, stitch_coords, crop_coords, crop_size in zip(
                    itertools.product(*all_coords),
                    itertools.product(*all_stitch_coords),
                    itertools.product(*all_crop_coords),
                    itertools.product(*all_crop_size),
                ):
                    tile_specs.append(
                        {
                            # PatchSpecs
                            "data_idx": data_idx,
                            "sample_idx": sample_idx,
                            "coords": coords,
                            "patch_size": self.tile_size,
                            # TileSpecs additional fields
                            "crop_coords": crop_coords,
                            "crop_size": crop_size,
                            "stitch_coords": stitch_coords,
                        }
                    )
        return tile_specs

    @staticmethod
    def _compute_1d_coords(
        axis_size: int, tile_size: int, overlap: int
    ) -> tuple[list[int], list[int], list[int], list[int]]:
        """
        Computes the TileSpec information for a single axis.

        The stitched regions (`stitch_coords[k]` to
        `stitch_coords[k] + crop_size[k]`) partition `[0, axis_size)` exactly,
        without gaps or double coverage. For every tile,
        `coords + crop_coords == stitch_coords` and
        `crop_coords + crop_size <= tile_size`.

        Parameters
        ----------
        axis_size : int
            The size of the axis.
        tile_size : int
            The tile size.
        overlap : int
            The tile overlap.

        Returns
        -------
        coords: list of int
            The top-left (and first z-slice for 3D data) of a tile, in coords relative
            to the image.
        stitch_coords: list of int
            Where the tile will be stitched back into an image, taking into account
            that the tile will be cropped, in coords relative to the image.
        crop_coords: list of int
            The top-left side of where the tile will be cropped, in coordinates relative
            to the tile.
        crop_size: list of int
            The size of the cropped tile.

        Raises
        ------
        ValueError
            If `overlap` is negative or not strictly smaller than `tile_size`.
        """
        if not 0 <= overlap < tile_size:
            raise ValueError(
                "Tile overlap must satisfy `0 <= overlap < tile_size`, got "
                f"`overlap={overlap}` and `tile_size={tile_size}`."
            )

        coords: list[int] = []
        stitch_coords: list[int] = []
        crop_coords: list[int] = []
        crop_size: list[int] = []

        step = tile_size - overlap
        half_overlap = overlap // 2
        for i in range(0, max(1, axis_size - overlap), step):
            if i == 0:
                coords.append(0)
                crop_coords.append(0)
                stitch_coords.append(0)
                if axis_size <= tile_size:
                    # a single tile spans the whole axis: keep all of it, and
                    # never crop beyond the image edge
                    crop_size.append(axis_size)
                else:
                    # end exactly where the next tile's stitched region starts
                    # (`step + half_overlap`); equals `tile_size - overlap // 2`
                    # for even overlaps
                    crop_size.append(step + half_overlap)
            elif i + tile_size < axis_size:
                # interior tile: drop half the overlap on each side
                coords.append(i)
                crop_coords.append(half_overlap)
                stitch_coords.append(i + half_overlap)
                crop_size.append(step)
            else:
                # last tile: shift it back so it ends at the image border, and
                # stitch from where the previous tile's region ended. The first
                # branch always ran before this one, so the lists are non-empty.
                previous_tile_end = stitch_coords[-1] + crop_size[-1]

                coords.append(max(0, axis_size - tile_size))
                stitch_coords.append(previous_tile_end)
                crop_coords.append(stitch_coords[-1] - coords[-1])
                crop_size.append(axis_size - stitch_coords[-1])

        return (
            coords,
            stitch_coords,
            crop_coords,
            crop_size,
        )
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class WholeSamplePatchingStrategy:
    """
    Patching strategy that returns every sample as one single, whole patch.

    One `PatchSpecs` dictionary is produced per sample of every data shape. Its
    `coords` are all zero and its `patch_size` is the spatial part of the data
    shape, so the "patch" is the entire sample.
    """

    # TODO: warn this strategy should only be used with batch size = 1
    # for the case of multiple image stacks with different dimensions

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        """
        Build the patch specs for all samples up front.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. The first entry of each shape is
            used as the number of samples and the entries from index 2 onwards as
            the spatial shape (presumably axes SC(Z)YX, as for the other
            patching strategies).
        """
        self.data_shapes = data_shapes
        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """The total number of patches, i.e. one per sample across all data."""
        specs = self.patch_specs
        return len(specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the patch specs stored at `index`.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary describing the whole-sample patch.
        """
        specs = self.patch_specs
        return specs[index]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Create one whole-sample `PatchSpecs` per sample of every data shape.

        Returns
        -------
        list of PatchSpecs
            Specs ordered by data index, then by sample index.
        """
        specs: list[PatchSpecs] = []
        for data_idx, shape in enumerate(self.data_shapes):
            spatial = shape[2:]
            # every whole-sample patch starts at the origin of the spatial axes
            origin = (0,) * len(spatial)
            specs.extend(
                {
                    "data_idx": data_idx,
                    "sample_idx": sample_idx,
                    "coords": origin,
                    "patch_size": spatial,
                }
                for sample_idx in range(shape[0])
            )
        return specs
|