careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Filter using an image mask."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
|
|
6
|
+
from careamics.dataset_ng.patch_filter.coordinate_filter_protocol import (
|
|
7
|
+
CoordinateFilterProtocol,
|
|
8
|
+
)
|
|
9
|
+
from careamics.dataset_ng.patching_strategies import PatchSpecs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO is it more intuitive to have a negative mask? (mask of what to avoid)
|
|
13
|
+
class MaskCoordFilter(CoordinateFilterProtocol):
    """
    Filter patch coordinates based on an image mask.

    Attributes
    ----------
    mask_extractor : PatchExtractor[GenericImageStack]
        Patch extractor for the binary mask to use for filtering.
    coverage : float
        Minimum fraction (between 0 and 1) of masked pixels required to keep a
        patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mask_extractor: PatchExtractor[GenericImageStack],
        coverage: float,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaskCoordFilter.

        This filter removes patches whose fraction of masked pixels falls below
        `coverage`. The mask is expected to be a positive mask, where masked
        (non-zero) pixels correspond to regions of interest.

        Parameters
        ----------
        mask_extractor : PatchExtractor[GenericImageStack]
            The patch extractor for the mask used for filtering.
        coverage : float
            Minimum fraction of masked pixels required to keep a patch. Must be
            between 0 and 1.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If coverage is not between 0 and 1.
        ValueError
            If p is not between 0 and 1.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Coverage must be between 0 and 1.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mask_extractor = mask_extractor
        self.coverage = coverage

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch_specs: PatchSpecs) -> bool:
        """
        Determine whether to filter out a patch based on the image mask.

        Parameters
        ----------
        patch_specs : PatchSpecs
            The patch coordinates to evaluate. Its entries are forwarded as
            keyword arguments to the mask extractor's `extract_patch`.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # With probability 1 - p the filter is skipped and the patch is kept.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        mask_patch = self.mask_extractor.extract_patch(**patch_specs)

        # Count non-zero pixels instead of summing values, so that masks stored
        # with labels other than 1 (e.g. 0/255) still yield a fraction in [0, 1].
        masked_fraction = np.count_nonzero(mask_patch) / mask_patch.size
        return bool(masked_fraction < self.coverage)
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Filter patch using a maximum filter."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.ndimage import maximum_filter
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
10
|
+
create_array_extractor,
|
|
11
|
+
)
|
|
12
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
13
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
14
|
+
from careamics.utils import get_logger
|
|
15
|
+
|
|
16
|
+
# Module-level logger; `MaxPatchFilter.apply_filter` reports image coverage with it.
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MaxPatchFilter(PatchFilterProtocol):
    """
    A patch filter based on thresholding the maximum filter of the patch.

    Inspired by the CSBDeep approach.

    Attributes
    ----------
    threshold : float
        Threshold for the maximum filter of the patch.
    threshold_ratio : float
        Fraction of maximum-filtered pixels that must be below `threshold` for
        the patch to be filtered out.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 1.0,
        threshold_ratio: float = 0.25,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaxPatchFilter.

        This filter removes patches whose maximum-filtered pixel values are below
        a specified threshold.

        Parameters
        ----------
        threshold : float
            Threshold for the maximum filter of the patch.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        threshold_ratio : float, default=0.25
            Ratio of pixels that must be below threshold for patch to be filtered out.
            Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If p is not between 0 and 1.
        ValueError
            If threshold_ratio is not between 0 and 1.
        """
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")
        if not (0 <= threshold_ratio <= 1):
            raise ValueError("Threshold ratio must be between 0 and 1.")

        self.threshold = threshold
        self.threshold_ratio = threshold_ratio
        self.p = p
        self.rng = np.random.default_rng(seed)

    @staticmethod
    def _max_filter_size(patch_size: Sequence[int]) -> list[int]:
        """Return the max-filter window: half the patch size per axis, at least 1."""
        return [max(size // 2, 1) for size in patch_size]

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its maximum filter.

        The patch is filtered out if its maximum is below `threshold`, or if the
        fraction of maximum-filtered pixels below `threshold` exceeds
        `threshold_ratio`.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # With probability 1 - p the filter is skipped and the patch is kept.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        # Cheap early exit: no pixel can reach the threshold.
        if np.max(patch) < self.threshold:
            return True

        filtered = maximum_filter(
            patch, self._max_filter_size(patch.shape), mode="constant"
        )
        return bool(np.mean(filtered < self.threshold) > self.threshold_ratio)

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the maximum map of an image.

        The map is computed over non-overlapping patches. This method can be used
        to assess a useful threshold for the MaxPatchFilter filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the map, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to compute the map over. Must have one entry
            per image dimension.

        Returns
        -------
        numpy.NDArray
            The max map of the image, with the same shape as `image`.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have one entry per image dimension.

        Example
        -------
        The `filter_map` method can be used to assess a useful threshold for the
        max filter. Below is an example of how to compute the max map of a random
        image and visualize thresholded versions of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MaxPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> max_filtered = MaxPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
        >>> for i, thresh in enumerate([50 + i*5 for i in range(5)]):
        ...     ax[i].imshow(max_filtered >= thresh, cmap="gray") # doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}") # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if image.ndim not in (2, 3):
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have one entry per image "
                f"dimension ({image.ndim})."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        max_filtered = np.zeros_like(image, dtype=float)

        extractor = create_array_extractor(source=[image], axes=axes)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            tile_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )
        # Same window rule as `filter_out`, so the map matches the real filter.
        filter_size = MaxPatchFilter._max_filter_size(patch_size)

        for idx in tqdm(range(tiling.n_patches), desc="Computing max map"):
            patch_spec = tiling.get_patch_spec(idx)
            coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=coords,
                patch_size=patch_size,
            )
            # Keep only the trailing spatial axes. Unlike `squeeze`, this keeps
            # spatial axes of size 1. Assumes any leading (sample/channel) axes
            # returned by the extractor are singletons.
            spatial_patch = patch.reshape(patch.shape[-len(patch_size) :])

            region = tuple(
                slice(start, start + size) for start, size in zip(coords, patch_size)
            )
            max_filtered[region] = maximum_filter(
                spatial_patch, filter_size, mode="constant"
            )

        return max_filtered

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the max filter to a filter map.

        The filter map is the output of the `filter_map` method.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The max filter map of the image.
        threshold : float
            The threshold to apply to the filter map.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        threshold_map = filter_map >= threshold
        coverage = np.sum(threshold_map) * 100 / threshold_map.size
        logger.info(f"Image coverage: {coverage:.2f}%")
        return threshold_map
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""Filter using mean and standard deviation thresholds."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
9
|
+
create_array_extractor,
|
|
10
|
+
)
|
|
11
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
12
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MeanStdPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on mean and standard deviation thresholds.

    Attributes
    ----------
    mean_threshold : float
        Threshold for the mean of the patch.
    std_threshold : float | None
        Threshold for the standard deviation of the patch, or None to disable
        standard deviation filtering.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mean_threshold: float,
        std_threshold: float | None = None,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MeanStdPatchFilter.

        This filter removes patches whose mean is below `mean_threshold`, or whose
        standard deviation is below `std_threshold` (if provided). A patch is kept
        only if it reaches both thresholds. The filtering is applied with a
        probability `p`, allowing for stochastic filtering.

        Parameters
        ----------
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If mean_threshold is negative.
        ValueError
            If std_threshold is negative.
        ValueError
            If p is not between 0 and 1.
        """
        if mean_threshold < 0:
            raise ValueError("Mean threshold must be non-negative.")
        if std_threshold is not None and std_threshold < 0:
            raise ValueError("Std threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mean_threshold = mean_threshold
        self.std_threshold = std_threshold

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on mean and std thresholds.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # With probability 1 - p the filter is skipped and the patch is kept.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        if np.mean(patch) < self.mean_threshold:
            return True
        return self.std_threshold is not None and bool(
            np.std(patch) < self.std_threshold
        )

    @staticmethod
    def filter_map(image: np.ndarray, patch_size: Sequence[int]) -> np.ndarray:
        """
        Compute the mean and std map of an image.

        The mean and std are computed over non-overlapping patches. This method can be
        used to assess a useful threshold for the MeanStd filter.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider. Must have one entry per image
            dimension.

        Returns
        -------
        np.ndarray
            Mean and std maps of the image stacked along a new leading axis
            (index 0: mean, index 1: std).

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have one entry per image dimension.

        Example
        -------
        The `filter_map` method can be used to assess useful thresholds for the
        MeanStd filter.
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MeanStdPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] = rng.normal(50, 3, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> meanstd_map = MeanStdPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(3, 3, figsize=(10, 10)) # doctest: +SKIP
        >>> for i, mean_thresh in enumerate([48 + k for k in range(3)]):
        ...     for j, std_thresh in enumerate([5 + k for k in range(3)]):
        ...         ax[i, j].imshow(
        ...             (meanstd_map[0, ...] >= mean_thresh)
        ...             & (meanstd_map[1, ...] >= std_thresh),
        ...             cmap="gray", vmin=0, vmax=1
        ...         ) # doctest: +SKIP
        ...         ax[i, j].set_title(
        ...             f"Mean: {mean_thresh}, Std: {std_thresh}"
        ...         ) # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if image.ndim not in (2, 3):
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have one entry per image "
                f"dimension ({image.ndim})."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        mean = np.zeros_like(image, dtype=float)
        std = np.zeros_like(image, dtype=float)

        extractor = create_array_extractor(source=[image], axes=axes)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            tile_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Mean/STD map"):
            patch_spec = tiling.get_patch_spec(idx)
            coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=coords,
                patch_size=patch_size,
            )

            region = tuple(
                slice(start, start + size) for start, size in zip(coords, patch_size)
            )
            # Every pixel of the tile holds the statistic of the whole tile.
            mean[region] = np.mean(patch)
            std[region] = np.std(patch)

        return np.stack([mean, std], axis=0)

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        mean_threshold: float,
        std_threshold: float | None = None,
    ) -> np.ndarray:
        """
        Apply mean and std thresholds to a filter map.

        The filter map is the output of the `filter_map` method. The comparison
        mirrors `filter_out`: a patch passes when its mean is at least
        `mean_threshold` and, if given, its std is at least `std_threshold`.

        Parameters
        ----------
        filter_map : np.ndarray
            Stacked mean and std maps of the image.
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.

        Returns
        -------
        np.ndarray
            A binary map where True indicates patches that pass the filter.
        """
        passes = filter_map[0, ...] >= mean_threshold
        if std_threshold is not None:
            passes &= filter_map[1, ...] >= std_threshold
        return passes
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""A protocol for patch filtering."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PatchFilterProtocol(Protocol):
    """
    An interface for implementing patch filtering strategies.

    Implementations decide, patch by patch, whether a patch should be excluded
    (`filter_out`). They also expose a static `filter_map` helper to preview the
    filter's statistics over a whole image.
    """

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a given patch.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out (excluded), False otherwise.
        """
        ...

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute a filter map for the entire image based on the patch filtering criteria.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider.

        Returns
        -------
        numpy.NDArray
            An array of the filter statistic(s) computed over the patches of the
            image. It can be thresholded to preview which regions would be
            filtered out. The exact layout (e.g. several statistics stacked
            along a leading axis) is implementation specific.
        """
        ...
|