careamics 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +7 -4
- careamics/config/configuration.py +6 -55
- careamics/config/configuration_factories.py +22 -12
- careamics/config/data/data_model.py +49 -9
- careamics/config/data/ng_data_model.py +167 -2
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/lightning/callbacks/data_stats_callback.py +13 -3
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +4 -3
- careamics/lightning/microsplit_data_module.py +15 -10
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/models/lvae/likelihoods.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +3 -2
- careamics/prediction_utils/stitch_prediction.py +17 -6
- careamics/utils/version.py +4 -4
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Filter patch using a maximum filter."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.ndimage import maximum_filter
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
10
|
+
create_array_extractor,
|
|
11
|
+
)
|
|
12
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
13
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
14
|
+
from careamics.utils import get_logger
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MaxPatchFilter(PatchFilterProtocol):
|
|
20
|
+
"""
|
|
21
|
+
A patch filter based on thresholding the maximum filter of the patch.
|
|
22
|
+
|
|
23
|
+
Inspired by the CSBDeep approach.
|
|
24
|
+
|
|
25
|
+
Attributes
|
|
26
|
+
----------
|
|
27
|
+
threshold : float
|
|
28
|
+
Threshold for the maximum filter of the patch.
|
|
29
|
+
p : float
|
|
30
|
+
Probability of applying the filter to a patch.
|
|
31
|
+
rng : np.random.Generator
|
|
32
|
+
Random number generator for stochastic filtering.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
threshold: float,
|
|
38
|
+
p: float = 1.0,
|
|
39
|
+
threshold_ratio: float = 0.25,
|
|
40
|
+
seed: int | None = None,
|
|
41
|
+
) -> None:
|
|
42
|
+
"""
|
|
43
|
+
Create a MaxPatchFilter.
|
|
44
|
+
|
|
45
|
+
This filter removes patches whose maximum-filtered pixel values are below a
|
|
46
|
+
specified threshold.
|
|
47
|
+
|
|
48
|
+
Parameters
|
|
49
|
+
----------
|
|
50
|
+
threshold : float
|
|
51
|
+
Threshold for the maximum filter of the patch.
|
|
52
|
+
p : float, default=1
|
|
53
|
+
Probability of applying the filter to a patch. Must be between 0 and 1.
|
|
54
|
+
threshold_ratio : float, default=0.25
|
|
55
|
+
Ratio of pixels that must be below threshold for patch to be filtered out.
|
|
56
|
+
Must be between 0 and 1.
|
|
57
|
+
seed : int | None, default=None
|
|
58
|
+
Seed for the random number generator for reproducibility.
|
|
59
|
+
"""
|
|
60
|
+
self.threshold = threshold
|
|
61
|
+
self.threshold_ratio = threshold_ratio
|
|
62
|
+
self.p = p
|
|
63
|
+
self.rng = np.random.default_rng(seed)
|
|
64
|
+
|
|
65
|
+
def filter_out(self, patch: np.ndarray) -> bool:
|
|
66
|
+
if self.rng.uniform(0, 1) < self.p:
|
|
67
|
+
|
|
68
|
+
if np.max(patch) < self.threshold:
|
|
69
|
+
return True
|
|
70
|
+
|
|
71
|
+
patch_shape = [(p // 2 if p > 1 else 1) for p in patch.shape]
|
|
72
|
+
filtered = maximum_filter(patch, patch_shape, mode="constant")
|
|
73
|
+
return np.mean(filtered < self.threshold) > self.threshold_ratio
|
|
74
|
+
|
|
75
|
+
return False
|
|
76
|
+
|
|
77
|
+
@staticmethod
|
|
78
|
+
def filter_map(
|
|
79
|
+
image: np.ndarray,
|
|
80
|
+
patch_size: Sequence[int],
|
|
81
|
+
) -> np.ndarray:
|
|
82
|
+
"""
|
|
83
|
+
Compute the maximum map of an image.
|
|
84
|
+
|
|
85
|
+
The map is computed over non-overlapping patches. This method can be used
|
|
86
|
+
to assess a useful threshold for the MaxPatchFilter filter.
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
image : numpy.NDArray
|
|
91
|
+
The image for which to compute the map, must be 2D or 3D.
|
|
92
|
+
patch_size : Sequence[int]
|
|
93
|
+
The size of the patches to compute the map over. Must be a sequence
|
|
94
|
+
of two (2D) or three (3D) integers.
|
|
95
|
+
|
|
96
|
+
Returns
|
|
97
|
+
-------
|
|
98
|
+
numpy.NDArray
|
|
99
|
+
The max map of the image.
|
|
100
|
+
|
|
101
|
+
Raises
|
|
102
|
+
------
|
|
103
|
+
ValueError
|
|
104
|
+
If the image is not 2D or 3D.
|
|
105
|
+
|
|
106
|
+
Example
|
|
107
|
+
-------
|
|
108
|
+
The `filter_map` method can be used to assess a useful threshold for the
|
|
109
|
+
maximum filter. Below is an example of how to compute and visualize
|
|
110
|
+
the maximum filter map of a random image and visualize thresholded versions
|
|
111
|
+
of the map.
|
|
112
|
+
>>> import numpy as np
|
|
113
|
+
>>> from matplotlib import pyplot as plt
|
|
114
|
+
>>> from careamics.dataset_ng.patch_filter import MaxPatchFilter
|
|
115
|
+
>>> rng = np.random.default_rng(42)
|
|
116
|
+
>>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
|
|
117
|
+
>>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
|
|
118
|
+
>>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
|
|
119
|
+
>>> patch_size = (16, 16)
|
|
120
|
+
>>> max_filtered = MaxPatchFilter.filter_map(image, patch_size)
|
|
121
|
+
>>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
|
|
122
|
+
>>> for i, thresh in enumerate([50 + i*5 for i in range(5)]):
|
|
123
|
+
... ax[i].imshow(max_filtered >= thresh, cmap="gray") # doctest: +SKIP
|
|
124
|
+
... ax[i].set_title(f"Threshold: {thresh}") # doctest: +SKIP
|
|
125
|
+
>>> plt.show() # doctest: +SKIP
|
|
126
|
+
"""
|
|
127
|
+
if len(image.shape) < 2 or len(image.shape) > 3:
|
|
128
|
+
raise ValueError("Image must be 2D or 3D.")
|
|
129
|
+
|
|
130
|
+
axes = "YX" if len(patch_size) == 2 else "ZYX"
|
|
131
|
+
|
|
132
|
+
max_filtered = np.zeros_like(image, dtype=float)
|
|
133
|
+
|
|
134
|
+
extractor = create_array_extractor(source=[image], axes=axes)
|
|
135
|
+
tiling = TilingStrategy(
|
|
136
|
+
data_shapes=[(1, 1, *image.shape)],
|
|
137
|
+
tile_size=patch_size,
|
|
138
|
+
overlaps=(0,) * len(patch_size), # no overlap
|
|
139
|
+
)
|
|
140
|
+
max_patch_size = [p // 2 for p in patch_size]
|
|
141
|
+
|
|
142
|
+
for idx in tqdm(range(tiling.n_patches), desc="Computing max map"):
|
|
143
|
+
patch_spec = tiling.get_patch_spec(idx)
|
|
144
|
+
patch = extractor.extract_patch(
|
|
145
|
+
data_idx=0,
|
|
146
|
+
sample_idx=0,
|
|
147
|
+
coords=patch_spec["coords"],
|
|
148
|
+
patch_size=patch_size,
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
coordinates = tuple(
|
|
152
|
+
slice(patch_spec["coords"][i], patch_spec["coords"][i] + p)
|
|
153
|
+
for i, p in enumerate(patch_size)
|
|
154
|
+
)
|
|
155
|
+
max_filtered[coordinates] = maximum_filter(
|
|
156
|
+
patch.squeeze(), max_patch_size, mode="constant"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
return max_filtered
|
|
160
|
+
|
|
161
|
+
@staticmethod
|
|
162
|
+
def apply_filter(
|
|
163
|
+
filter_map: np.ndarray,
|
|
164
|
+
threshold: float,
|
|
165
|
+
) -> np.ndarray:
|
|
166
|
+
"""
|
|
167
|
+
Apply the max filter to a filter map.
|
|
168
|
+
|
|
169
|
+
The filter map is the output of the `filter_map` method.
|
|
170
|
+
|
|
171
|
+
Parameters
|
|
172
|
+
----------
|
|
173
|
+
filter_map : numpy.NDArray
|
|
174
|
+
The max filter map of the image.
|
|
175
|
+
threshold : float
|
|
176
|
+
The threshold to apply to the filter map.
|
|
177
|
+
|
|
178
|
+
Returns
|
|
179
|
+
-------
|
|
180
|
+
numpy.NDArray
|
|
181
|
+
A boolean array where True indicates that the patch should be kept
|
|
182
|
+
(not filtered out) and False indicates that the patch should be filtered
|
|
183
|
+
out.
|
|
184
|
+
"""
|
|
185
|
+
threshold_map = filter_map >= threshold
|
|
186
|
+
coverage = np.sum(threshold_map) * 100 / threshold_map.size
|
|
187
|
+
logger.info(f"Image coverage: {coverage:.2f}%")
|
|
188
|
+
return threshold_map
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""Filter using mean and standard deviation thresholds."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
9
|
+
create_array_extractor,
|
|
10
|
+
)
|
|
11
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
12
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MeanStdPatchFilter(PatchFilterProtocol):
|
|
16
|
+
"""
|
|
17
|
+
Filter patches based on mean and standard deviation thresholds.
|
|
18
|
+
|
|
19
|
+
Attributes
|
|
20
|
+
----------
|
|
21
|
+
mean_threshold : float
|
|
22
|
+
Threshold for the mean of the patch.
|
|
23
|
+
std_threshold : float
|
|
24
|
+
Threshold for the standard deviation of the patch.
|
|
25
|
+
p : float
|
|
26
|
+
Probability of applying the filter to a patch.
|
|
27
|
+
rng : np.random.Generator
|
|
28
|
+
Random number generator for stochastic filtering.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
mean_threshold: float,
|
|
34
|
+
std_threshold: float | None = None,
|
|
35
|
+
p: float = 1.0,
|
|
36
|
+
seed: int | None = None,
|
|
37
|
+
) -> None:
|
|
38
|
+
"""
|
|
39
|
+
Create a MeanStdPatchFilter.
|
|
40
|
+
|
|
41
|
+
This filter removes patches whose mean or standard deviation falls below the
|
|
42
|
+
specified thresholds. The filtering is applied with a probability `p`, allowing
|
|
43
|
+
for stochastic filtering.
|
|
44
|
+
|
|
45
|
+
Parameters
|
|
46
|
+
----------
|
|
47
|
+
mean_threshold : float
|
|
48
|
+
Threshold for the mean of the patch.
|
|
49
|
+
std_threshold : float | None, default=None
|
|
50
|
+
Threshold for the standard deviation of the patch. If None, then no
|
|
51
|
+
standard deviation filtering is applied.
|
|
52
|
+
p : float, default=1
|
|
53
|
+
Probability of applying the filter to a patch. Must be between 0 and 1.
|
|
54
|
+
seed : int | None, default=None
|
|
55
|
+
Seed for the random number generator for reproducibility.
|
|
56
|
+
|
|
57
|
+
Raises
|
|
58
|
+
------
|
|
59
|
+
ValueError
|
|
60
|
+
If mean_threshold is negative.
|
|
61
|
+
ValueError
|
|
62
|
+
If std_threshold is negative.
|
|
63
|
+
ValueError
|
|
64
|
+
If p is not between 0 and 1.
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
if mean_threshold < 0:
|
|
68
|
+
raise ValueError("Mean threshold must be non-negative.")
|
|
69
|
+
if std_threshold is not None and std_threshold < 0:
|
|
70
|
+
raise ValueError("Std threshold must be non-negative.")
|
|
71
|
+
if not (0 <= p <= 1):
|
|
72
|
+
raise ValueError("Probability p must be between 0 and 1.")
|
|
73
|
+
|
|
74
|
+
self.mean_threshold = mean_threshold
|
|
75
|
+
self.std_threshold = std_threshold
|
|
76
|
+
|
|
77
|
+
self.p = p
|
|
78
|
+
self.rng = np.random.default_rng(seed)
|
|
79
|
+
|
|
80
|
+
def filter_out(self, patch: np.ndarray) -> bool:
|
|
81
|
+
"""
|
|
82
|
+
Determine whether to filter out a patch based on mean and std thresholds.
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
patch : numpy.NDArray
|
|
87
|
+
The image patch to evaluate.
|
|
88
|
+
|
|
89
|
+
Returns
|
|
90
|
+
-------
|
|
91
|
+
bool
|
|
92
|
+
True if the patch should be filtered out, False otherwise.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
if self.rng.uniform(0, 1) < self.p:
|
|
96
|
+
patch_mean = np.mean(patch)
|
|
97
|
+
patch_std = np.std(patch)
|
|
98
|
+
|
|
99
|
+
return (patch_mean < self.mean_threshold) or (
|
|
100
|
+
self.std_threshold is not None and patch_std < self.std_threshold
|
|
101
|
+
)
|
|
102
|
+
return False
|
|
103
|
+
|
|
104
|
+
@staticmethod
|
|
105
|
+
def filter_map(image: np.ndarray, patch_size: Sequence[int]) -> np.ndarray:
|
|
106
|
+
"""
|
|
107
|
+
Compute the mean and std map of an image.
|
|
108
|
+
|
|
109
|
+
The mean and std are computed over non-overlapping patches. This method can be
|
|
110
|
+
used to assess a useful threshold for the MeanStd filter.
|
|
111
|
+
|
|
112
|
+
Parameters
|
|
113
|
+
----------
|
|
114
|
+
image : numpy.NDArray
|
|
115
|
+
The full image to evaluate.
|
|
116
|
+
patch_size : Sequence[int]
|
|
117
|
+
The size of the patches to consider.
|
|
118
|
+
|
|
119
|
+
Returns
|
|
120
|
+
-------
|
|
121
|
+
np.ndarray
|
|
122
|
+
Stacked mean and std maps of the image.
|
|
123
|
+
|
|
124
|
+
Raises
|
|
125
|
+
------
|
|
126
|
+
ValueError
|
|
127
|
+
If the image is not 2D or 3D.
|
|
128
|
+
|
|
129
|
+
Example
|
|
130
|
+
-------
|
|
131
|
+
The `filter_map` method can be used to assess useful thresholds for the
|
|
132
|
+
MeanStd filter.
|
|
133
|
+
>>> import numpy as np
|
|
134
|
+
>>> import matplotlib.pyplot as plt
|
|
135
|
+
>>> from careamics.dataset_ng.patch_filter import MeanStdPatchFilter
|
|
136
|
+
>>> rng = np.random.default_rng(42)
|
|
137
|
+
>>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
|
|
138
|
+
>>> image[64:192, 64:192] = rng.normal(50, 3, (128, 128))
|
|
139
|
+
>>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
|
|
140
|
+
>>> patch_size = (16, 16)
|
|
141
|
+
>>> meanstd_map = MeanStdPatchFilter.filter_map(image, patch_size)
|
|
142
|
+
>>> fig, ax = plt.subplots(3, 3, figsize=(10, 10)) # doctest: +SKIP
|
|
143
|
+
>>> for i, mean_thresh in enumerate([48 + i for i in range(3)]):
|
|
144
|
+
... for j, std_thresh in enumerate([5 + i for i in range(3)]):
|
|
145
|
+
... ax[i, j].imshow(
|
|
146
|
+
... (meanstd_map[0, ...] > mean_thresh)
|
|
147
|
+
... & (meanstd_map[1, ...] > std_thresh),
|
|
148
|
+
... cmap="gray", vmin=0, vmax=1
|
|
149
|
+
... ) # doctest: +SKIP
|
|
150
|
+
... ax[i, j].set_title(
|
|
151
|
+
... f"Mean: {mean_thresh}, Std: {std_thresh}"
|
|
152
|
+
... ) # doctest: +SKIP
|
|
153
|
+
>>> plt.show() # doctest: +SKIP
|
|
154
|
+
"""
|
|
155
|
+
if len(image.shape) < 2 or len(image.shape) > 3:
|
|
156
|
+
raise ValueError("Image must be 2D or 3D.")
|
|
157
|
+
|
|
158
|
+
axes = "YX" if len(patch_size) == 2 else "ZYX"
|
|
159
|
+
|
|
160
|
+
mean = np.zeros_like(image, dtype=float)
|
|
161
|
+
std = np.zeros_like(image, dtype=float)
|
|
162
|
+
|
|
163
|
+
extractor = create_array_extractor(source=[image], axes=axes)
|
|
164
|
+
tiling = TilingStrategy(
|
|
165
|
+
data_shapes=[(1, 1, *image.shape)],
|
|
166
|
+
tile_size=patch_size,
|
|
167
|
+
overlaps=(0,) * len(patch_size), # no overlap
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
for idx in tqdm(range(tiling.n_patches), desc="Computing Mean/STD map"):
|
|
171
|
+
patch_spec = tiling.get_patch_spec(idx)
|
|
172
|
+
patch = extractor.extract_patch(
|
|
173
|
+
data_idx=0,
|
|
174
|
+
sample_idx=0,
|
|
175
|
+
coords=patch_spec["coords"],
|
|
176
|
+
patch_size=patch_size,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
coordinates = tuple(
|
|
180
|
+
slice(patch_spec["coords"][i], patch_spec["coords"][i] + p)
|
|
181
|
+
for i, p in enumerate(patch_size)
|
|
182
|
+
)
|
|
183
|
+
mean[coordinates] = np.mean(patch)
|
|
184
|
+
std[coordinates] = np.std(patch)
|
|
185
|
+
|
|
186
|
+
return np.stack([mean, std], axis=0)
|
|
187
|
+
|
|
188
|
+
@staticmethod
|
|
189
|
+
def apply_filter(
|
|
190
|
+
filter_map: np.ndarray,
|
|
191
|
+
mean_threshold: float,
|
|
192
|
+
std_threshold: float | None = None,
|
|
193
|
+
) -> np.ndarray:
|
|
194
|
+
"""
|
|
195
|
+
Apply mean and std thresholds to a filter map.
|
|
196
|
+
|
|
197
|
+
The filter map is the output of the `filter_map` method.
|
|
198
|
+
|
|
199
|
+
Parameters
|
|
200
|
+
----------
|
|
201
|
+
filter_map : np.ndarray
|
|
202
|
+
Stacked mean and std maps of the image.
|
|
203
|
+
mean_threshold : float
|
|
204
|
+
Threshold for the mean of the patch.
|
|
205
|
+
std_threshold : float | None, default=None
|
|
206
|
+
Threshold for the standard deviation of the patch. If None, then no
|
|
207
|
+
standard deviation filtering is applied.
|
|
208
|
+
|
|
209
|
+
Returns
|
|
210
|
+
-------
|
|
211
|
+
np.ndarray
|
|
212
|
+
A binary map where True indicates patches that pass the filter.
|
|
213
|
+
"""
|
|
214
|
+
if std_threshold is not None:
|
|
215
|
+
return (filter_map[0, ...] > mean_threshold) & (
|
|
216
|
+
filter_map[1, ...] > std_threshold
|
|
217
|
+
)
|
|
218
|
+
return filter_map[0, ...] > mean_threshold
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""A protocol for patch filtering."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PatchFilterProtocol(Protocol):
|
|
10
|
+
"""
|
|
11
|
+
An interface for implementing patch filtering strategies.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def filter_out(self, patch: np.ndarray) -> bool:
|
|
15
|
+
"""
|
|
16
|
+
Determine whether to filter out a given patch.
|
|
17
|
+
|
|
18
|
+
Parameters
|
|
19
|
+
----------
|
|
20
|
+
patch : numpy.NDArray
|
|
21
|
+
The image patch to evaluate.
|
|
22
|
+
|
|
23
|
+
Returns
|
|
24
|
+
-------
|
|
25
|
+
bool
|
|
26
|
+
True if the patch should be filtered out (excluded), False otherwise.
|
|
27
|
+
"""
|
|
28
|
+
...
|
|
29
|
+
|
|
30
|
+
@staticmethod
|
|
31
|
+
def filter_map(
|
|
32
|
+
image: np.ndarray,
|
|
33
|
+
patch_size: Sequence[int],
|
|
34
|
+
) -> np.ndarray:
|
|
35
|
+
"""
|
|
36
|
+
Compute a filter map for the entire image based on the patch filtering criteria.
|
|
37
|
+
|
|
38
|
+
Parameters
|
|
39
|
+
----------
|
|
40
|
+
image : numpy.NDArray
|
|
41
|
+
The full image to evaluate.
|
|
42
|
+
patch_size : Sequence[int]
|
|
43
|
+
The size of the patches to consider.
|
|
44
|
+
|
|
45
|
+
Returns
|
|
46
|
+
-------
|
|
47
|
+
numpy.NDArray
|
|
48
|
+
A map of the filter criterion values computed over the patches of the image.
|
|
49
|
+
"""
|
|
50
|
+
...
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Filter patches based on Shannon entropy threshold."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from skimage.measure import shannon_entropy
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
|
|
10
|
+
create_array_extractor,
|
|
11
|
+
)
|
|
12
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
13
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ShannonPatchFilter(PatchFilterProtocol):
|
|
17
|
+
"""
|
|
18
|
+
Filter patches based on Shannon entropy threshold.
|
|
19
|
+
|
|
20
|
+
Attributes
|
|
21
|
+
----------
|
|
22
|
+
threshold : float
|
|
23
|
+
Threshold for the Shannon entropy of the patch.
|
|
24
|
+
p : float
|
|
25
|
+
Probability of applying the filter to a patch.
|
|
26
|
+
rng : np.random.Generator
|
|
27
|
+
Random number generator for stochastic filtering.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self, threshold: float, p: float = 1.0, seed: int | None = None
|
|
32
|
+
) -> None:
|
|
33
|
+
"""
|
|
34
|
+
Create a ShannonPatchFilter.
|
|
35
|
+
|
|
36
|
+
This filter removes patches whose Shannon entropy is below a specified
|
|
37
|
+
threshold.
|
|
38
|
+
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
threshold : float
|
|
42
|
+
Threshold for the Shannon entropy of the patch.
|
|
43
|
+
p : float, default=1
|
|
44
|
+
Probability of applying the filter to a patch. Must be between 0 and 1.
|
|
45
|
+
seed : int | None, default=None
|
|
46
|
+
Seed for the random number generator for reproducibility.
|
|
47
|
+
|
|
48
|
+
Raises
|
|
49
|
+
------
|
|
50
|
+
ValueError
|
|
51
|
+
If threshold is negative.
|
|
52
|
+
ValueError
|
|
53
|
+
If p is not between 0 and 1.
|
|
54
|
+
"""
|
|
55
|
+
if threshold < 0:
|
|
56
|
+
raise ValueError("Threshold must be non-negative.")
|
|
57
|
+
if not (0 <= p <= 1):
|
|
58
|
+
raise ValueError("Probability p must be between 0 and 1.")
|
|
59
|
+
|
|
60
|
+
self.threshold = threshold
|
|
61
|
+
|
|
62
|
+
self.p = p
|
|
63
|
+
self.rng = np.random.default_rng(seed)
|
|
64
|
+
|
|
65
|
+
def filter_out(self, patch: np.ndarray) -> bool:
|
|
66
|
+
"""
|
|
67
|
+
Determine whether to filter out a patch based on its Shannon entropy.
|
|
68
|
+
|
|
69
|
+
Parameters
|
|
70
|
+
----------
|
|
71
|
+
patch : numpy.NDArray
|
|
72
|
+
The patch to evaluate.
|
|
73
|
+
|
|
74
|
+
Returns
|
|
75
|
+
-------
|
|
76
|
+
bool
|
|
77
|
+
True if the patch should be filtered out, False otherwise.
|
|
78
|
+
"""
|
|
79
|
+
if self.rng.uniform(0, 1) < self.p:
|
|
80
|
+
return shannon_entropy(patch) < self.threshold
|
|
81
|
+
return False
|
|
82
|
+
|
|
83
|
+
@staticmethod
|
|
84
|
+
def filter_map(
|
|
85
|
+
image: np.ndarray,
|
|
86
|
+
patch_size: Sequence[int],
|
|
87
|
+
) -> np.ndarray:
|
|
88
|
+
"""
|
|
89
|
+
Compute the Shannon entropy map of an image.
|
|
90
|
+
|
|
91
|
+
The entropy is computed over non-overlapping patches. This method can be used
|
|
92
|
+
to assess a useful threshold for the Shannon entropy filter.
|
|
93
|
+
|
|
94
|
+
Parameters
|
|
95
|
+
----------
|
|
96
|
+
image : numpy.NDArray
|
|
97
|
+
The image for which to compute the entropy map, must be 2D or 3D.
|
|
98
|
+
patch_size : Sequence[int]
|
|
99
|
+
The size of the patches to compute the entropy over. Must be a sequence
|
|
100
|
+
of two (2D) or three (3D) integers.
|
|
101
|
+
|
|
102
|
+
Returns
|
|
103
|
+
-------
|
|
104
|
+
numpy.NDArray
|
|
105
|
+
The Shannon entropy map of the image.
|
|
106
|
+
|
|
107
|
+
Raises
|
|
108
|
+
------
|
|
109
|
+
ValueError
|
|
110
|
+
If the image is not 2D or 3D.
|
|
111
|
+
|
|
112
|
+
Example
|
|
113
|
+
-------
|
|
114
|
+
The `filter_map` method can be used to assess a useful threshold for the
|
|
115
|
+
Shannon entropy filter. Below is an example of how to compute and visualize
|
|
116
|
+
the Shannon entropy map of a random image and visualize thresholded versions
|
|
117
|
+
of the map.
|
|
118
|
+
>>> import numpy as np
|
|
119
|
+
>>> from matplotlib import pyplot as plt
|
|
120
|
+
>>> from careamics.dataset_ng.patch_filter import ShannonPatchFilter
|
|
121
|
+
>>> rng = np.random.default_rng(42)
|
|
122
|
+
>>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
|
|
123
|
+
>>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
|
|
124
|
+
>>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
|
|
125
|
+
>>> patch_size = (16, 16)
|
|
126
|
+
>>> entropy_map = ShannonPatchFilter.filter_map(image, patch_size)
|
|
127
|
+
>>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
|
|
128
|
+
>>> for i, thresh in enumerate([2 + 1.5 * i for i in range(5)]):
|
|
129
|
+
... ax[i].imshow(entropy_map >= thresh, cmap="gray") #doctest: +SKIP
|
|
130
|
+
... ax[i].set_title(f"Threshold: {thresh}") #doctest: +SKIP
|
|
131
|
+
>>> plt.show() # doctest: +SKIP
|
|
132
|
+
"""
|
|
133
|
+
if len(image.shape) < 2 or len(image.shape) > 3:
|
|
134
|
+
raise ValueError("Image must be 2D or 3D.")
|
|
135
|
+
|
|
136
|
+
axes = "YX" if len(patch_size) == 2 else "ZYX"
|
|
137
|
+
|
|
138
|
+
shannon_img = np.zeros_like(image, dtype=float)
|
|
139
|
+
|
|
140
|
+
extractor = create_array_extractor(source=[image], axes=axes)
|
|
141
|
+
tiling = TilingStrategy(
|
|
142
|
+
data_shapes=[(1, 1, *image.shape)],
|
|
143
|
+
tile_size=patch_size,
|
|
144
|
+
overlaps=(0,) * len(patch_size), # no overlap
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
for idx in tqdm(range(tiling.n_patches), desc="Computing Shannon Entropy map"):
|
|
148
|
+
patch_spec = tiling.get_patch_spec(idx)
|
|
149
|
+
patch = extractor.extract_patch(
|
|
150
|
+
data_idx=0,
|
|
151
|
+
sample_idx=0,
|
|
152
|
+
coords=patch_spec["coords"],
|
|
153
|
+
patch_size=patch_size,
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
coordinates = tuple(
|
|
157
|
+
slice(patch_spec["coords"][i], patch_spec["coords"][i] + p)
|
|
158
|
+
for i, p in enumerate(patch_size)
|
|
159
|
+
)
|
|
160
|
+
shannon_img[coordinates] = shannon_entropy(patch)
|
|
161
|
+
|
|
162
|
+
return shannon_img
|
|
163
|
+
|
|
164
|
+
@staticmethod
|
|
165
|
+
def apply_filter(
|
|
166
|
+
filter_map: np.ndarray,
|
|
167
|
+
threshold: float,
|
|
168
|
+
) -> np.ndarray:
|
|
169
|
+
"""
|
|
170
|
+
Apply the Shannon entropy filter to a precomputed filter map.
|
|
171
|
+
|
|
172
|
+
The filter map is the output of the `filter_map` method.
|
|
173
|
+
|
|
174
|
+
Parameters
|
|
175
|
+
----------
|
|
176
|
+
filter_map : numpy.NDArray
|
|
177
|
+
The precomputed Shannon entropy map of the image.
|
|
178
|
+
threshold : float
|
|
179
|
+
The Shannon entropy threshold for filtering.
|
|
180
|
+
|
|
181
|
+
Returns
|
|
182
|
+
-------
|
|
183
|
+
numpy.NDArray
|
|
184
|
+
A boolean array where True indicates that the patch should be kept
|
|
185
|
+
(not filtered out) and False indicates that the patch should be filtered
|
|
186
|
+
out.
|
|
187
|
+
"""
|
|
188
|
+
return filter_map >= threshold
|
|
@@ -7,12 +7,22 @@ from pytorch_lightning.callbacks import Callback
|
|
|
7
7
|
class DataStatsCallback(Callback):
|
|
8
8
|
"""Callback to update model's data statistics from datamodule.
|
|
9
9
|
|
|
10
|
-
This callback ensures that the model has access to the data statistics (mean and
|
|
11
|
-
calculated by the datamodule before training starts.
|
|
10
|
+
This callback ensures that the model has access to the data statistics (mean and
|
|
11
|
+
std) calculated by the datamodule before training starts.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
14
|
def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
|
|
15
|
-
"""Called when trainer is setting up.
|
|
15
|
+
"""Called when trainer is setting up.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
trainer : Lightning.Trainer
|
|
20
|
+
The trainer instance.
|
|
21
|
+
module : Lightning.LightningModule
|
|
22
|
+
The model being trained.
|
|
23
|
+
stage : str
|
|
24
|
+
The current stage of training (e.g., 'fit', 'validate', 'test', 'predict').
|
|
25
|
+
"""
|
|
16
26
|
if stage == "fit":
|
|
17
27
|
# Get data statistics from datamodule
|
|
18
28
|
(data_mean, data_std), _ = trainer.datamodule.get_data_stats()
|