careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Patch filtering strategies."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CoordinateFilterProtocol",
|
|
5
|
+
"MaskCoordFilter",
|
|
6
|
+
"MaxPatchFilter",
|
|
7
|
+
"MeanStdPatchFilter",
|
|
8
|
+
"PatchFilterProtocol",
|
|
9
|
+
"ShannonPatchFilter",
|
|
10
|
+
"create_coord_filter",
|
|
11
|
+
"create_patch_filter",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
from .coordinate_filter_protocol import CoordinateFilterProtocol
|
|
15
|
+
from .filter_factory import create_coord_filter, create_patch_filter
|
|
16
|
+
from .mask_filter import MaskCoordFilter
|
|
17
|
+
from .max_filter import MaxPatchFilter
|
|
18
|
+
from .mean_std_filter import MeanStdPatchFilter
|
|
19
|
+
from .patch_filter_protocol import PatchFilterProtocol
|
|
20
|
+
from .shannon_filter import ShannonPatchFilter
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""A protocol for patch filtering."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
from careamics.dataset_ng.patching_strategies import PatchSpecs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CoordinateFilterProtocol(Protocol):
    """
    Interface for strategies that filter patches based on their coordinates.

    Implementations decide, from the patch specification alone, whether a patch
    should be excluded from the set of patches used downstream.
    """

    def filter_out(self, patch: PatchSpecs) -> bool:
        """
        Decide whether a patch should be excluded, based on its coordinates.

        Parameters
        ----------
        patch : PatchSpecs
            Specification (coordinates) of the patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be excluded, False if it should be kept.
        """
        ...
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Factories for coordinate and patch filters."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from careamics.config.data.patch_filter import (
|
|
6
|
+
FilterConfig,
|
|
7
|
+
MaskFilterConfig,
|
|
8
|
+
MaxFilterConfig,
|
|
9
|
+
MeanSTDFilterConfig,
|
|
10
|
+
ShannonFilterConfig,
|
|
11
|
+
)
|
|
12
|
+
from careamics.config.support.supported_filters import (
|
|
13
|
+
SupportedCoordinateFilters,
|
|
14
|
+
SupportedPatchFilters,
|
|
15
|
+
)
|
|
16
|
+
from careamics.dataset_ng.image_stack import GenericImageStack
|
|
17
|
+
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
18
|
+
|
|
19
|
+
from .mask_filter import MaskCoordFilter
|
|
20
|
+
from .max_filter import MaxPatchFilter
|
|
21
|
+
from .mean_std_filter import MeanStdPatchFilter
|
|
22
|
+
from .shannon_filter import ShannonPatchFilter
|
|
23
|
+
|
|
24
|
+
# Union of every patch filter type that `create_patch_filter` can build.
PatchFilter = Union[MaxPatchFilter, MeanStdPatchFilter, ShannonPatchFilter]

# Union of every coordinate filter type that `create_coord_filter` can build.
# Only the mask-based filter exists for now; the alias leaves room for more.
CoordFilter = Union[MaskCoordFilter]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def create_coord_filter(
    filter_model: FilterConfig, mask: PatchExtractor[GenericImageStack]
) -> CoordFilter:
    """
    Build a coordinate filter from its configuration.

    Parameters
    ----------
    filter_model : FilterConfig
        Pydantic configuration of the filter to create.
    mask : PatchExtractor[GenericImageStack]
        Patch extractor over the mask, handed to the mask filter.

    Returns
    -------
    CoordFilter
        The configured coordinate filter.

    Raises
    ------
    ValueError
        If `filter_model.name` is not a supported coordinate filter.
    """
    # Guard clause: only the mask filter is currently supported.
    if filter_model.name != SupportedCoordinateFilters.MASK:
        raise ValueError(f"Unknown filter name: {filter_model}")

    # Narrow the config type so the mask-specific fields can be accessed.
    assert isinstance(filter_model, MaskFilterConfig)
    return MaskCoordFilter(
        mask_extractor=mask,
        coverage=filter_model.coverage,
        p=filter_model.p,
        seed=filter_model.seed,
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def create_patch_filter(filter_model: FilterConfig) -> PatchFilter:
|
|
64
|
+
"""Factory function to create patch filter instances based on the filter name.
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
filter_model : FilterModel
|
|
69
|
+
Pydantic model of the filter to be created.
|
|
70
|
+
|
|
71
|
+
Returns
|
|
72
|
+
-------
|
|
73
|
+
PatchFilter
|
|
74
|
+
Instance of the requested patch filter.
|
|
75
|
+
"""
|
|
76
|
+
if filter_model.name == SupportedPatchFilters.MAX:
|
|
77
|
+
assert isinstance(filter_model, MaxFilterConfig)
|
|
78
|
+
return MaxPatchFilter(
|
|
79
|
+
threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
|
|
80
|
+
)
|
|
81
|
+
elif filter_model.name == SupportedPatchFilters.MEANSTD:
|
|
82
|
+
assert isinstance(filter_model, MeanSTDFilterConfig)
|
|
83
|
+
return MeanStdPatchFilter(
|
|
84
|
+
mean_threshold=filter_model.mean_threshold,
|
|
85
|
+
std_threshold=filter_model.std_threshold,
|
|
86
|
+
p=filter_model.p,
|
|
87
|
+
seed=filter_model.seed,
|
|
88
|
+
)
|
|
89
|
+
elif filter_model.name == SupportedPatchFilters.SHANNON:
|
|
90
|
+
assert isinstance(filter_model, ShannonFilterConfig)
|
|
91
|
+
return ShannonPatchFilter(
|
|
92
|
+
threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
|
|
93
|
+
)
|
|
94
|
+
else:
|
|
95
|
+
raise ValueError(f"Unknown filter name: {filter_model}")
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Filter using an image mask."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from careamics.dataset_ng.image_stack import GenericImageStack
|
|
6
|
+
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
7
|
+
from careamics.dataset_ng.patch_filter.coordinate_filter_protocol import (
|
|
8
|
+
CoordinateFilterProtocol,
|
|
9
|
+
)
|
|
10
|
+
from careamics.dataset_ng.patching_strategies import PatchSpecs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# TODO is it more intuitive to have a negative mask? (mask of what to avoid)
|
|
14
|
+
class MaskCoordFilter(CoordinateFilterProtocol):
    """
    Filter patch coordinates based on an image mask.

    Attributes
    ----------
    mask_extractor : PatchExtractor[GenericImageStack]
        Patch extractor for the binary mask to use for filtering.
    coverage : float
        Minimum fraction (between 0 and 1) of masked pixels required to keep a
        patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mask_extractor: PatchExtractor[GenericImageStack],
        coverage: float,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaskCoordFilter.

        This filter removes patches whose fraction of masked pixels falls below
        `coverage`. The mask is expected to be a positive mask where masked pixels
        correspond to regions of interest.

        Parameters
        ----------
        mask_extractor : PatchExtractor[GenericImageStack]
            The patch extractor for the mask used for filtering.
        coverage : float
            Minimum fraction of masked pixels required to keep a patch. Must be
            between 0 and 1.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If coverage is not between 0 and 1.
        ValueError
            If p is not between 0 and 1.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Coverage must be between 0 and 1.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mask_extractor = mask_extractor
        self.coverage = coverage

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch_specs: PatchSpecs) -> bool:
        """
        Determine whether to filter out a patch based on an image mask.

        Parameters
        ----------
        patch_specs : PatchSpecs
            Specification of the patch to evaluate; it is unpacked as keyword
            arguments into the mask extractor's `extract_patch`.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # The filter is only applied with probability `p`; otherwise keep the patch.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        mask_patch = self.mask_extractor.extract_patch(**patch_specs)

        # Sum of mask values over the number of pixels: this is the masked-pixel
        # fraction when the mask holds 0/1 (or boolean) values.
        masked_fraction = np.sum(mask_patch) / mask_patch.size
        return bool(masked_fraction < self.coverage)
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Filter patch using a maximum filter."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.ndimage import maximum_filter
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from careamics.dataset_ng.image_stack_loader import load_arrays
|
|
10
|
+
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
11
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
12
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
13
|
+
from careamics.utils import get_logger
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MaxPatchFilter(PatchFilterProtocol):
    """
    A patch filter based on thresholding the maximum filter of the patch.

    Inspired by the CSBDeep approach.

    Attributes
    ----------
    threshold : float
        Threshold for the maximum filter of the patch.
    threshold_ratio : float
        Fraction of maximum-filtered pixels that must be below `threshold` for a
        patch to be filtered out.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 1.0,
        threshold_ratio: float = 0.25,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaxPatchFilter.

        This filter removes patches for which the fraction of maximum-filtered
        pixel values below `threshold` exceeds `threshold_ratio`.

        Parameters
        ----------
        threshold : float
            Threshold for the maximum filter of the patch.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        threshold_ratio : float, default=0.25
            Ratio of pixels that must be below threshold for patch to be filtered out.
            Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If p is not between 0 and 1.
        ValueError
            If threshold_ratio is not between 0 and 1.
        """
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")
        if not (0 <= threshold_ratio <= 1):
            raise ValueError("Threshold ratio must be between 0 and 1.")

        self.threshold = threshold
        self.threshold_ratio = threshold_ratio
        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its maximum filter.

        Parameters
        ----------
        patch : numpy.ndarray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # The filter is only applied with probability `p`; otherwise keep the patch.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        # Fast path: the whole patch is below the threshold.
        if np.max(patch) < self.threshold:
            return True

        # Window of half the patch extent along each axis (at least 1 pixel).
        window_size = [(size // 2 if size > 1 else 1) for size in patch.shape]
        filtered = maximum_filter(patch, window_size, mode="constant")
        return bool(np.mean(filtered < self.threshold) > self.threshold_ratio)

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the maximum map of an image.

        The map is computed over non-overlapping patches. This method can be used
        to assess a useful threshold for the MaxPatchFilter filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the map, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to compute the map over. Must contain two
            integers for a 2D image, or three for a 3D image.

        Returns
        -------
        numpy.NDArray
            The max map of the image.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If the length of `patch_size` does not match the image dimensionality.

        Examples
        --------
        The `filter_map` method can be used to assess a useful threshold for the
        max filter. Below is an example of how to compute the max filter map of a
        random image and visualize thresholded versions of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MaxPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> max_filtered = MaxPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5))  # doctest: +SKIP
        >>> for i, thresh in enumerate([50 + i*5 for i in range(5)]):
        ...     ax[i].imshow(max_filtered >= thresh, cmap="gray")  # doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}")  # doctest: +SKIP
        >>> plt.show()  # doctest: +SKIP
        """
        if len(image.shape) < 2 or len(image.shape) > 3:
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != len(image.shape):
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have as many dimensions as "
                f"the image ({len(image.shape)}D)."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        max_filtered = np.zeros_like(image, dtype=float)

        image_stacks = load_arrays(source=[image], axes=axes)
        extractor = PatchExtractor(image_stacks)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            patch_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )
        # Max filter window: half the patch size along each axis.
        window_size = [size // 2 for size in patch_size]

        for idx in tqdm(range(tiling.n_patches), desc="Computing max map"):
            patch_spec = tiling.get_patch_spec(idx)
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=patch_spec["coords"],
                patch_size=patch_size,
            )

            # Region of the output map covered by this tile.
            coordinates = tuple(
                slice(start, start + size)
                for start, size in zip(patch_spec["coords"], patch_size)
            )
            max_filtered[coordinates] = maximum_filter(
                patch.squeeze(), window_size, mode="constant"
            )

        return max_filtered

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the max filter to a filter map.

        The filter map is the output of the `filter_map` method.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The max filter map of the image.
        threshold : float
            The threshold to apply to the filter map.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        threshold_map = filter_map >= threshold
        # Percentage of pixels that pass the threshold, logged as a diagnostic.
        coverage = np.sum(threshold_map) * 100 / threshold_map.size
        logger.info(f"Image coverage: {coverage:.2f}%")
        return threshold_map
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""Filter using mean and standard deviation thresholds."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from careamics.dataset_ng.image_stack_loader import load_arrays
|
|
9
|
+
from careamics.dataset_ng.patch_extractor import PatchExtractor
|
|
10
|
+
from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
|
|
11
|
+
from careamics.dataset_ng.patching_strategies import TilingStrategy
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MeanStdPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on mean and standard deviation thresholds.

    A patch is filtered out when its mean is below `mean_threshold`, or, when
    `std_threshold` is set, when its standard deviation is below
    `std_threshold`.

    Attributes
    ----------
    mean_threshold : float
        Threshold for the mean of the patch.
    std_threshold : float | None
        Threshold for the standard deviation of the patch, or None if the
        standard deviation is not used for filtering.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mean_threshold: float,
        std_threshold: float | None = None,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MeanStdPatchFilter.

        This filter removes patches whose mean is below `mean_threshold`, or whose
        standard deviation is below `std_threshold` (when one is given). The
        filtering is applied with a probability `p`, allowing for stochastic
        filtering.

        Parameters
        ----------
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If mean_threshold is negative.
        ValueError
            If std_threshold is negative.
        ValueError
            If p is not between 0 and 1.
        """
        if mean_threshold < 0:
            raise ValueError("Mean threshold must be non-negative.")
        if std_threshold is not None and std_threshold < 0:
            raise ValueError("Std threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mean_threshold = mean_threshold
        self.std_threshold = std_threshold

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on mean and std thresholds.

        With probability `p` the patch is evaluated; otherwise it is always kept.
        An evaluated patch is filtered out if its mean is below `mean_threshold`,
        or if `std_threshold` is set and its standard deviation is below it.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # Exactly one random draw per call, whatever the outcome, so a seeded
        # filter produces a reproducible sequence of decisions.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        if np.mean(patch) < self.mean_threshold:
            return True

        # The std is only needed (and only computed) when a std threshold is set.
        if self.std_threshold is None:
            return False
        return bool(np.std(patch) < self.std_threshold)

    @staticmethod
    def filter_map(image: np.ndarray, patch_size: Sequence[int]) -> np.ndarray:
        """
        Compute the mean and std map of an image.

        The mean and std are computed over non-overlapping patches, and each
        patch's statistics are written to every pixel the patch covers. This
        method can be used to assess a useful threshold for the MeanStd filter.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider, one entry per image dimension.

        Returns
        -------
        np.ndarray
            Stacked mean and std maps of the image: index 0 along the first axis
            is the mean map and index 1 the std map, each with the image's shape.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have one entry per image dimension.

        Example
        -------
        The `filter_map` method can be used to assess useful thresholds for the
        MeanStd filter.
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt  # doctest: +SKIP
        >>> from careamics.dataset_ng.patch_filter import MeanStdPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] = rng.normal(50, 3, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> meanstd_map = MeanStdPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(3, 3, figsize=(10, 10))  # doctest: +SKIP
        >>> for i, mean_thresh in enumerate([48, 49, 50]):
        ...     for j, std_thresh in enumerate([5, 6, 7]):
        ...         keep = MeanStdPatchFilter.apply_filter(
        ...             meanstd_map, mean_thresh, std_thresh
        ...         )
        ...         ax[i, j].imshow(keep, cmap="gray", vmin=0, vmax=1)
        ...         ax[i, j].set_title(
        ...             f"Mean: {mean_thresh}, Std: {std_thresh}"
        ...         )  # doctest: +SKIP
        >>> plt.show()  # doctest: +SKIP
        """
        if image.ndim not in (2, 3):
            raise ValueError("Image must be 2D or 3D.")
        if len(patch_size) != image.ndim:
            # Without this check a 2D patch size on a 3D image (or vice versa)
            # selects the wrong axes and fails later with an unclear error.
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have one entry per image "
                f"dimension (image is {image.ndim}D)."
            )

        axes = "YX" if image.ndim == 2 else "ZYX"

        mean = np.zeros_like(image, dtype=float)
        std = np.zeros_like(image, dtype=float)

        # Tile the image without overlap and read each tile through the
        # PatchExtractor. The leading (1, 1) presumably stands for a single
        # sample and a single channel -- TODO confirm against TilingStrategy.
        image_stacks = load_arrays(source=[image], axes=axes)
        extractor = PatchExtractor(image_stacks)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            patch_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Mean/STD map"):
            patch_spec = tiling.get_patch_spec(idx)
            start_coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=start_coords,
                patch_size=patch_size,
            )

            # Write the patch statistics over the region the patch covers.
            region = tuple(
                slice(start, start + size)
                for start, size in zip(start_coords, patch_size)
            )
            mean[region] = np.mean(patch)
            std[region] = np.std(patch)

        return np.stack([mean, std], axis=0)

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        mean_threshold: float,
        std_threshold: float | None = None,
    ) -> np.ndarray:
        """
        Apply mean and std thresholds to a filter map.

        The filter map is the output of the `filter_map` method. A pixel passes
        when its mean is at least `mean_threshold` and, if `std_threshold` is
        given, its std is at least `std_threshold`. This matches `filter_out`,
        which only removes patches strictly below a threshold.

        Parameters
        ----------
        filter_map : np.ndarray
            Stacked mean and std maps of the image.
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.

        Returns
        -------
        np.ndarray
            A binary map where True indicates patches that pass the filter.
        """
        keep = filter_map[0, ...] >= mean_threshold
        if std_threshold is not None:
            keep &= filter_map[1, ...] >= std_threshold
        return keep
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""A protocol for patch filtering."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class PatchFilterProtocol(Protocol):
    """
    An interface for implementing patch filtering strategies.

    Implementations decide, patch by patch, whether a patch should be excluded
    (`filter_out`), and can compute a whole-image map of the statistic they
    filter on (`filter_map`) to help choose filter thresholds.
    """

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a given patch.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out (excluded), False otherwise.
        """
        ...

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute a filter map for the entire image based on the patch filtering criteria.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider.

        Returns
        -------
        numpy.NDArray
            A map of the filtering statistic computed over patches of size
            `patch_size`, aligned with the image pixels. Thresholding it shows
            which regions of the image would be kept by the filter. The exact
            layout (e.g. a single map, or several stacked maps) depends on the
            implementation.
        """
        ...
|