careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,95 @@
1
+ """Filter using an image mask."""
2
+
3
+ import numpy as np
4
+
5
+ from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
6
+ from careamics.dataset_ng.patch_filter.coordinate_filter_protocol import (
7
+ CoordinateFilterProtocol,
8
+ )
9
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
10
+
11
+
12
+ # TODO is it more intuitive to have a negative mask? (mask of what to avoid)
13
class MaskCoordFilter(CoordinateFilterProtocol):
    """
    Filter patch coordinates based on an image mask.

    Attributes
    ----------
    mask_extractor : PatchExtractor[GenericImageStack]
        Patch extractor for the binary mask to use for filtering.
    coverage : float
        Minimum fraction (between 0 and 1) of masked pixels required to keep a
        patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mask_extractor: PatchExtractor[GenericImageStack],
        coverage: float,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaskCoordFilter.

        This filter removes patches whose fraction of masked pixels falls below a
        threshold. The mask is expected to be a positive mask where masked pixels
        correspond to regions of interest.

        Parameters
        ----------
        mask_extractor : PatchExtractor[GenericImageStack]
            The patch extractor for the mask used for filtering.
        coverage : float
            Minimum fraction of masked pixels required to keep a patch. Must be
            between 0 and 1.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If coverage is not between 0 and 1.
        ValueError
            If p is not between 0 and 1.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Coverage must be between 0 and 1.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mask_extractor = mask_extractor
        self.coverage = coverage

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch_specs: PatchSpecs) -> bool:
        """
        Determine whether to filter out a patch based on an image mask.

        With probability `p` the mask patch at the same coordinates is extracted
        and the patch is filtered out if its fraction of masked pixels is below
        `coverage`. Any non-zero mask value counts as masked, so boolean, 0/1
        and 0/255 masks are all handled consistently.

        Parameters
        ----------
        patch_specs : PatchSpecs
            The patch coordinates to evaluate. They are forwarded as keyword
            arguments to `mask_extractor.extract_patch`.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # A random number is drawn on every call; the filter is skipped
        # (patch kept) when the draw is not below `p`.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        mask_patch = self.mask_extractor.extract_patch(**patch_specs)

        # An empty patch contains no masked pixels; avoid a 0/0 division.
        if mask_patch.size == 0:
            return True

        masked_fraction = np.count_nonzero(mask_patch) / mask_patch.size
        return masked_fraction < self.coverage
@@ -0,0 +1,188 @@
1
+ """Filter patch using a maximum filter."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+ from scipy.ndimage import maximum_filter
7
+ from tqdm import tqdm
8
+
9
+ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
10
+ create_array_extractor,
11
+ )
12
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
13
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
14
+ from careamics.utils import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
class MaxPatchFilter(PatchFilterProtocol):
    """
    A patch filter based on thresholding the maximum filter of the patch.

    Inspired by the CSBDeep approach.

    Attributes
    ----------
    threshold : float
        Threshold for the maximum filter of the patch.
    threshold_ratio : float
        Fraction of pixels whose maximum-filtered value is below `threshold`
        above which a patch is filtered out.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 1.0,
        threshold_ratio: float = 0.25,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaxPatchFilter.

        This filter removes patches whose maximum filter values are below a
        specified threshold.

        Parameters
        ----------
        threshold : float
            Threshold for the maximum filter of the patch.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        threshold_ratio : float, default=0.25
            Ratio of pixels that must be below threshold for patch to be filtered out.
            Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If p is not between 0 and 1.
        ValueError
            If threshold_ratio is not between 0 and 1.
        """
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")
        if not (0 <= threshold_ratio <= 1):
            raise ValueError("Threshold ratio must be between 0 and 1.")

        self.threshold = threshold
        self.threshold_ratio = threshold_ratio
        self.p = p
        self.rng = np.random.default_rng(seed)

    @staticmethod
    def _max_filter_size(shape: Sequence[int]) -> list[int]:
        """Return the maximum-filter footprint: half of each axis, at least 1."""
        return [size // 2 if size > 1 else 1 for size in shape]

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its maximum filter.

        With probability `p` the filter is applied. A patch is then filtered out
        if its maximum is below `threshold`. It is also filtered out if the
        fraction of pixels whose maximum-filtered value is below `threshold`
        exceeds `threshold_ratio`. The maximum filter uses a footprint of half
        the patch size along each axis and `mode="constant"`.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        if self.rng.uniform(0, 1) >= self.p:
            return False

        # Cheap early exit: nothing in the patch reaches the threshold.
        if np.max(patch) < self.threshold:
            return True

        filtered = maximum_filter(
            patch, self._max_filter_size(patch.shape), mode="constant"
        )
        return bool(np.mean(filtered < self.threshold) > self.threshold_ratio)

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the maximum map of an image.

        The map is computed over non-overlapping patches. This method can be used
        to assess a useful threshold for the MaxPatchFilter filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the map, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to compute the map over. Must have as many
            entries as the image has dimensions (2 for YX, 3 for ZYX).

        Returns
        -------
        numpy.NDArray
            The maximum-filtered map of the image, with the same shape as `image`.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have the same number of dimensions as the
            image.

        Example
        -------
        The `filter_map` method can be used to assess a useful threshold for the
        MaxPatchFilter. Below is an example of how to compute the max map of a
        random image and visualize thresholded versions of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MaxPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> max_filtered = MaxPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
        >>> for i, thresh in enumerate([50 + k * 5 for k in range(5)]):
        ...     ax[i].imshow(max_filtered >= thresh, cmap="gray") # doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}") # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if image.ndim < 2 or image.ndim > 3:
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have the same number of "
                f"dimensions as the image ({image.ndim}D)."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        max_filtered = np.zeros_like(image, dtype=float)

        extractor = create_array_extractor(source=[image], axes=axes)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            tile_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )
        filter_size = MaxPatchFilter._max_filter_size(patch_size)

        for idx in tqdm(range(tiling.n_patches), desc="Computing max map"):
            patch_spec = tiling.get_patch_spec(idx)
            coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=coords,
                patch_size=patch_size,
            )

            region = tuple(
                slice(start, start + size) for start, size in zip(coords, patch_size)
            )
            # NOTE(review): `squeeze` presumably drops a singleton channel axis
            # added by the extractor; it would also drop any patch axis of size 1.
            max_filtered[region] = maximum_filter(
                patch.squeeze(), filter_size, mode="constant"
            )

        return max_filtered

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the max filter to a filter map.

        The filter map is the output of the `filter_map` method. The percentage
        of the map that passes the threshold is logged at INFO level.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The max filter map of the image.
        threshold : float
            The threshold to apply to the filter map.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        threshold_map = filter_map >= threshold
        coverage = np.sum(threshold_map) * 100 / threshold_map.size
        logger.info(f"Image coverage: {coverage:.2f}%")
        return threshold_map
@@ -0,0 +1,218 @@
1
+ """Filter using mean and standard deviation thresholds."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
9
+ create_array_extractor,
10
+ )
11
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
12
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
13
+
14
+
15
class MeanStdPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on mean and standard deviation thresholds.

    Attributes
    ----------
    mean_threshold : float
        Threshold for the mean of the patch.
    std_threshold : float | None
        Threshold for the standard deviation of the patch, or None to disable
        standard deviation filtering.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mean_threshold: float,
        std_threshold: float | None = None,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MeanStdPatchFilter.

        This filter removes patches whose mean is below `mean_threshold`, or
        whose standard deviation is below `std_threshold` (when set). A patch
        is kept only if both statistics reach their thresholds. The filtering
        is applied with a probability `p`, allowing for stochastic filtering.

        Parameters
        ----------
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If mean_threshold is negative.
        ValueError
            If std_threshold is negative.
        ValueError
            If p is not between 0 and 1.
        """
        if mean_threshold < 0:
            raise ValueError("Mean threshold must be non-negative.")
        if std_threshold is not None and std_threshold < 0:
            raise ValueError("Std threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mean_threshold = mean_threshold
        self.std_threshold = std_threshold

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on mean and std thresholds.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise. A patch
            is filtered out if its mean is below `mean_threshold`, or if
            `std_threshold` is set and its standard deviation is below it.
        """
        if self.rng.uniform(0, 1) >= self.p:
            return False

        if np.mean(patch) < self.mean_threshold:
            return True

        # The standard deviation is only computed when it can change the outcome.
        return self.std_threshold is not None and bool(
            np.std(patch) < self.std_threshold
        )

    @staticmethod
    def filter_map(image: np.ndarray, patch_size: Sequence[int]) -> np.ndarray:
        """
        Compute the mean and std map of an image.

        The mean and std are computed over non-overlapping patches. This method can be
        used to assess a useful threshold for the MeanStd filter.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to consider. Must have as many entries as the
            image has dimensions.

        Returns
        -------
        np.ndarray
            Stacked mean and std maps of the image, with shape
            `(2, *image.shape)`; index 0 is the mean map and index 1 the std map.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have the same number of dimensions as the
            image.

        Example
        -------
        The `filter_map` method can be used to assess useful thresholds for the
        MeanStd filter.
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MeanStdPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] = rng.normal(50, 3, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> meanstd_map = MeanStdPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(3, 3, figsize=(10, 10)) # doctest: +SKIP
        >>> for i, mean_thresh in enumerate([48 + k for k in range(3)]):
        ...     for j, std_thresh in enumerate([5 + k for k in range(3)]):
        ...         ax[i, j].imshow(
        ...             MeanStdPatchFilter.apply_filter(
        ...                 meanstd_map, mean_thresh, std_thresh
        ...             ),
        ...             cmap="gray", vmin=0, vmax=1
        ...         ) # doctest: +SKIP
        ...         ax[i, j].set_title(
        ...             f"Mean: {mean_thresh}, Std: {std_thresh}"
        ...         ) # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if image.ndim < 2 or image.ndim > 3:
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have the same number of "
                f"dimensions as the image ({image.ndim}D)."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        mean = np.zeros_like(image, dtype=float)
        std = np.zeros_like(image, dtype=float)

        extractor = create_array_extractor(source=[image], axes=axes)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            tile_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Mean/STD map"):
            patch_spec = tiling.get_patch_spec(idx)
            coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=coords,
                patch_size=patch_size,
            )

            region = tuple(
                slice(start, start + size) for start, size in zip(coords, patch_size)
            )
            # Every pixel of the tile receives the statistics of its tile.
            mean[region] = np.mean(patch)
            std[region] = np.std(patch)

        return np.stack([mean, std], axis=0)

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        mean_threshold: float,
        std_threshold: float | None = None,
    ) -> np.ndarray:
        """
        Apply mean and std thresholds to a filter map.

        The filter map is the output of the `filter_map` method. The comparison
        uses `>=`, which matches `filter_out`: a patch whose statistic equals
        the threshold is kept.

        Parameters
        ----------
        filter_map : np.ndarray
            Stacked mean and std maps of the image.
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.

        Returns
        -------
        np.ndarray
            A binary map where True indicates patches that pass the filter.
        """
        keep = filter_map[0, ...] >= mean_threshold
        if std_threshold is not None:
            keep &= filter_map[1, ...] >= std_threshold
        return keep
@@ -0,0 +1,50 @@
1
+ """A protocol for patch filtering."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Protocol
5
+
6
+ import numpy as np
7
+
8
+
9
class PatchFilterProtocol(Protocol):
    """
    An interface for implementing patch filtering strategies.

    Implementations decide, patch by patch, whether a patch should be excluded
    (`filter_out`). They also provide a static helper (`filter_map`) that
    computes per-patch statistics over a whole image, to help choose thresholds.
    """

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a given patch.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out (excluded), False otherwise.
        """
        ...

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute a filter map for the entire image based on the patch filtering criteria.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider.

        Returns
        -------
        numpy.NDArray
            An array holding the filter statistic(s) of each patch, broadcast
            over the pixels of that patch. The exact layout is defined by the
            implementation (e.g. one leading entry per statistic).
        """
        ...