kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/wsi/backends/openslide.py
@@ -0,0 +1,73 @@
+"""Module for loading data from WSI files using the OpenSlide library."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import openslide
+from typing_extensions import override
+
+from eva.vision.data.wsi.backends import base
+
+
+class WsiOpenslide(base.Wsi):
+    """Class for loading data from WSI files using the OpenSlide library."""
+
+    _wsi: openslide.OpenSlide
+
+    @override
+    def open_file(self, file_path: str) -> openslide.OpenSlide:
+        return openslide.OpenSlide(file_path)
+
+    @property
+    @override
+    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+        return self._wsi.level_dimensions
+
+    @property
+    @override
+    def level_downsamples(self) -> Sequence[float]:
+        return self._wsi.level_downsamples
+
+    @property
+    @override
+    def mpp(self) -> float:
+        # TODO: add overwrite_mpp class attribute to allow setting a default value
+        if self._wsi.properties.get(openslide.PROPERTY_NAME_MPP_X) and self._wsi.properties.get(
+            openslide.PROPERTY_NAME_MPP_Y
+        ):
+            x_mpp = float(self._wsi.properties[openslide.PROPERTY_NAME_MPP_X])
+            y_mpp = float(self._wsi.properties[openslide.PROPERTY_NAME_MPP_Y])
+        elif (
+            self._wsi.properties.get("tiff.XResolution")
+            and self._wsi.properties.get("tiff.YResolution")
+            and self._wsi.properties.get("tiff.ResolutionUnit")
+        ):
+            unit = self._wsi.properties.get("tiff.ResolutionUnit")
+            if unit not in _conversion_factor_to_micrometer:
+                raise ValueError(f"Unit {unit} not supported.")
+
+            conversion_factor = float(_conversion_factor_to_micrometer.get(unit))  # type: ignore
+            x_mpp = conversion_factor / float(self._wsi.properties["tiff.XResolution"])
+            y_mpp = conversion_factor / float(self._wsi.properties["tiff.YResolution"])
+        else:
+            raise ValueError("`mpp` cannot be obtained for this slide.")
+
+        return (x_mpp + y_mpp) / 2.0
+
+    @override
+    def _read_region(
+        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+    ) -> np.ndarray:
+        return np.array(self._wsi.read_region(location, level, size))
+
+
+_conversion_factor_to_micrometer = {
+    "meter": 10**6,
+    "decimeter": 10**5,
+    "centimeter": 10**4,
+    "millimeter": 10**3,
+    "micrometer": 1,
+    "nanometer": 10**-3,
+    "picometer": 10**-6,
+    "femtometer": 10**-9,
+}
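The TIFF-tag fallback in `mpp` above is plain unit conversion: the resolution tags store pixels per unit, and the table at the bottom of the file converts that unit to micrometers. A minimal sketch of the arithmetic, with hypothetical tag values (not taken from a real slide):

# Hypothetical TIFF resolution tags: 40000 pixels per centimeter.
conversion_factor = 10**4          # micrometers per centimeter ("centimeter" table entry)
x_resolution = 40000.0             # tiff.XResolution, pixels per ResolutionUnit
x_mpp = conversion_factor / x_resolution
print(x_mpp)                       # 0.25 micrometers per pixel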
eva/vision/data/wsi/backends/pil.py
@@ -0,0 +1,52 @@
+"""Module for loading data from standard image file formats using the PIL library."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import PIL.Image
+from typing_extensions import override
+
+from eva.vision.data.wsi.backends import base
+
+
+class PILImage(base.Wsi):
+    """Class for loading data from standard image file formats using the PIL library."""
+
+    _wsi: PIL.Image.Image
+
+    @override
+    def open_file(self, file_path: str) -> PIL.Image.Image:
+        return PIL.Image.open(file_path).convert("RGB")
+
+    @property
+    @override
+    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+        return [self._wsi.size]
+
+    @property
+    @override
+    def level_downsamples(self) -> Sequence[float]:
+        return [1.0]
+
+    @property
+    @override
+    def mpp(self) -> float:
+        if self._overwrite_mpp is None:
+            raise ValueError("Please specify the mpp using the `overwrite_mpp` argument.")
+        return self._overwrite_mpp
+
+    @override
+    def _read_region(
+        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+    ) -> np.ndarray:
+        width, height = size[0], size[1]
+        patch = self._wsi.crop(
+            # (left, upper, right, lower)
+            (
+                location[0],
+                location[1],
+                location[0] + width,
+                location[1] + height,
+            )
+        )
+        return np.array(patch)
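Because plain images carry no resolution metadata, this backend only works when the caller supplies the mpp. A usage sketch, assuming (as `PatchCoordinates.from_file` further below suggests) that `base.Wsi` takes the file path plus an optional `overwrite_mpp` and opens the file itself; the file name is hypothetical:

from eva.vision.data.wsi.backends.pil import PILImage

image = PILImage("patch.png", 0.5)  # 0.5 micrometers per pixel, supplied manually
print(image.level_dimensions)       # a single level: [(width, height)]
print(image.mpp)                    # 0.5; raises ValueError when overwrite_mpp is None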
eva/vision/data/wsi/backends/tiffslide.py
@@ -0,0 +1,42 @@
+"""Module for loading data from WSI files using the TiffSlide library."""
+
+from typing import Sequence, Tuple
+
+import numpy as np
+import tiffslide  # type: ignore
+from typing_extensions import override
+
+from eva.vision.data.wsi.backends import base
+
+
+class WsiTiffslide(base.Wsi):
+    """Class for loading data from WSI files using the TiffSlide library."""
+
+    _wsi: tiffslide.TiffSlide
+
+    @override
+    def open_file(self, file_path: str) -> tiffslide.TiffSlide:
+        return tiffslide.TiffSlide(file_path)
+
+    @property
+    @override
+    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+        return self._wsi.level_dimensions
+
+    @property
+    @override
+    def level_downsamples(self) -> Sequence[float]:
+        return self._wsi.level_downsamples
+
+    @property
+    @override
+    def mpp(self) -> float:
+        x_mpp = float(self._wsi.properties[tiffslide.PROPERTY_NAME_MPP_X])
+        y_mpp = float(self._wsi.properties[tiffslide.PROPERTY_NAME_MPP_Y])
+        return (x_mpp + y_mpp) / 2.0
+
+    @override
+    def _read_region(
+        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+    ) -> np.ndarray:
+        return np.array(self._wsi.read_region(location, level, size))
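All three backends override the same five members, which sketches the contract a new reader would fill in. A hypothetical skeleton (not part of the package), assuming `base.Wsi` calls `open_file` itself:

from typing import Sequence, Tuple

import numpy as np
from typing_extensions import override

from eva.vision.data.wsi.backends import base


class WsiMyFormat(base.Wsi):
    """Skeleton backend: every reader supplies these five members."""

    @override
    def open_file(self, file_path: str): ...  # return the underlying file handle

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]: ...

    @property
    @override
    def level_downsamples(self) -> Sequence[float]: ...

    @property
    @override
    def mpp(self) -> float: ...  # micrometers per pixel at level 0

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray: ...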
eva/vision/data/wsi/patching/coordinates.py
@@ -0,0 +1,98 @@
+"""A module for handling coordinates of patches from a whole-slide image."""
+
+import dataclasses
+import functools
+from typing import List, Tuple
+
+from eva.vision.data.wsi import backends
+from eva.vision.data.wsi.patching import samplers
+from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
+
+LRU_CACHE_SIZE = 32
+
+
+@dataclasses.dataclass
+class PatchCoordinates:
+    """A class to store coordinates of patches from a whole-slide image.
+
+    Args:
+        x_y: A list of (x, y) coordinates of the patches (refer to level 0).
+        width: The width of the patches, in pixels (refers to level_idx).
+        height: The height of the patches, in pixels (refers to level_idx).
+        level_idx: The level index at which to extract the patches.
+        mask: The foreground mask of the wsi.
+    """
+
+    x_y: List[Tuple[int, int]]
+    width: int
+    height: int
+    level_idx: int
+    mask: Mask | None = None
+
+    @classmethod
+    def from_file(
+        cls,
+        wsi_path: str,
+        width: int,
+        height: int,
+        sampler: samplers.Sampler,
+        target_mpp: float,
+        overwrite_mpp: float | None = None,
+        backend: str = "openslide",
+    ) -> "PatchCoordinates":
+        """Create a new instance of PatchCoordinates from a whole-slide image file.
+
+        Patches will be read from the level that is closest to the specified target_mpp.
+
+        Args:
+            wsi_path: The path to the whole-slide image file.
+            width: The width of the patches to be extracted, in pixels.
+            height: The height of the patches to be extracted, in pixels.
+            target_mpp: The target microns per pixel (mpp) for the patches.
+            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+        """
+        wsi = backends.wsi_backend(backend)(wsi_path, overwrite_mpp)
+
+        # Sample patch coordinates at level 0
+        mpp_ratio_0 = target_mpp / wsi.mpp
+        sample_args = {
+            "width": int(mpp_ratio_0 * width),
+            "height": int(mpp_ratio_0 * height),
+            "layer_shape": wsi.level_dimensions[0],
+        }
+        if isinstance(sampler, samplers.ForegroundSampler):
+            mask_level_idx = get_mask_level(wsi, width, height, target_mpp)
+            sample_args["mask"] = get_mask(wsi, mask_level_idx)
+
+        x_y = list(sampler.sample(**sample_args))
+
+        # Scale dimensions to level that is closest to the target_mpp
+        level_idx = wsi.get_closest_level(target_mpp)
+        mpp_ratio = target_mpp / (wsi.mpp * wsi.level_downsamples[level_idx])
+        scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+
+        return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
+
+
+@functools.lru_cache(LRU_CACHE_SIZE)
+def get_cached_coords(
+    file_path: str,
+    width: int,
+    height: int,
+    target_mpp: float,
+    overwrite_mpp: float | None,
+    sampler: samplers.Sampler,
+    backend: str,
+) -> PatchCoordinates:
+    """Get a cached instance of PatchCoordinates for the specified parameters."""
+    return PatchCoordinates.from_file(
+        wsi_path=file_path,
+        width=width,
+        height=height,
+        target_mpp=target_mpp,
+        overwrite_mpp=overwrite_mpp,
+        backend=backend,
+        sampler=sampler,
+    )
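A usage sketch of the entry point above, with a hypothetical slide path: sample up to 16 non-overlapping 224x224 patch origins at 0.5 microns per pixel.

from eva.vision.data.wsi.patching import samplers
from eva.vision.data.wsi.patching.coordinates import PatchCoordinates

coords = PatchCoordinates.from_file(
    wsi_path="slide.svs",  # hypothetical file
    width=224,
    height=224,
    sampler=samplers.GridSampler(max_samples=16),
    target_mpp=0.5,
    backend="openslide",
)
# coords.x_y holds level-0 (x, y) origins; coords.width/height and
# coords.level_idx describe the read size at the level closest to target_mpp.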
eva/vision/data/wsi/patching/mask.py
@@ -0,0 +1,123 @@
+"""Functions for extracting foreground masks."""
+
+import dataclasses
+from typing import Tuple
+
+import cv2
+import numpy as np
+
+from eva.vision.data.wsi.backends.base import Wsi
+
+
+@dataclasses.dataclass
+class Mask:
+    """A class to store the mask of a whole-slide image."""
+
+    mask_array: np.ndarray
+    """Binary mask array where 1s represent the foreground and 0s represent the background."""
+
+    mask_level_idx: int
+    """WSI level index at which the mask_array was extracted."""
+
+    scale_factors: Tuple[float, float]
+    """Factors to scale x/y coordinates from mask_level_idx to level 0."""
+
+
+def get_mask(
+    wsi: Wsi,
+    mask_level_idx: int,
+    saturation_threshold: int = 20,
+    median_blur_kernel_size: int | None = None,
+    fill_holes: bool = False,
+    holes_kernel_size: Tuple[int, int] = (7, 7),
+    use_otsu: bool = False,
+) -> Mask:
+    """Generates a binary foreground mask for a given WSI.
+
+    This is a simplified version of the algorithm proposed in [1] (CLAM):
+    1. Convert the image to the HSV color space (easier to separate specific colors than with RGB).
+    2. (optional) Apply a median blur to the saturation channel to reduce noise
+    & close small gaps in the mask. While this yields cleaner masks, this step is the most
+    computationally expensive and thus disabled by default (CLAM uses a value of 7).
+    3. Calculate the binary mask by thresholding across the saturation channel.
+
+    [1] Lu, Ming Y., et al. "Data-efficient and weakly supervised computational
+    pathology on whole-slide images." Nature Biomedical Engineering 5.6 (2021): 555-570.
+    https://github.com/mahmoodlab/CLAM
+
+    Args:
+        wsi: The WSI object.
+        mask_level_idx: The level index of the WSI at which we want to extract the mask.
+        saturation_threshold: The threshold value for the saturation channel.
+        median_blur_kernel_size: Kernel size for the median blur operation.
+        holes_kernel_size: The size of the kernel for morphological operations to fill holes.
+        fill_holes: Whether to fill holes in the mask.
+        use_otsu: Whether to use Otsu's method for the thresholding operation. If False,
+            a fixed threshold value is used.
+
+    Returns: A Mask object instance.
+    """
+    image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx])
+    image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+    image = (
+        cv2.medianBlur(image[:, :, 1], median_blur_kernel_size)
+        if median_blur_kernel_size
+        else image[:, :, 1]
+    )
+
+    threshold_type = cv2.THRESH_BINARY + cv2.THRESH_OTSU if use_otsu else cv2.THRESH_BINARY
+    _, mask_array = cv2.threshold(image, saturation_threshold, 1, threshold_type)
+
+    if fill_holes:
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, holes_kernel_size)
+        mask_array = cv2.dilate(mask_array, kernel, iterations=1)
+        contour, _ = cv2.findContours(mask_array, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
+        for cnt in contour:
+            cv2.drawContours(mask_array, [cnt], 0, (1,), -1)
+
+    mask_array = mask_array.astype(np.uint8)
+    scale_factors = (
+        wsi.level_dimensions[0][0] / wsi.level_dimensions[mask_level_idx][0],
+        wsi.level_dimensions[0][1] / wsi.level_dimensions[mask_level_idx][1],
+    )
+
+    return Mask(mask_array=mask_array, mask_level_idx=mask_level_idx, scale_factors=scale_factors)
+
+
+def get_mask_level(
+    wsi: Wsi,
+    width: int,
+    height: int,
+    target_mpp: float,
+    min_mask_patch_pixels: int = 3 * 3,
+) -> int:
+    """For performance reasons, we generate the mask at the lowest resolution level possible.
+
+    However, if the minimum resolution level has too few pixels, the patches scaled to that level
+    will be too small or even collapse to a single pixel. This function finds the lowest
+    resolution level that yields mask patches with at least `min_mask_patch_pixels` pixels.
+
+    Args:
+        wsi: The WSI object.
+        width: The width of the patches to be extracted, in pixels (at target_mpp).
+        height: The height of the patches to be extracted, in pixels.
+        target_mpp: The target microns per pixel (mpp) for the patches.
+        min_mask_patch_pixels: The minimum number of pixels required for the mask patches.
+            Mask patch refers to width / height at target_mpp scaled down to the WSI level
+            at which the mask is generated.
+    """
+    level_mpps = wsi.mpp * np.array(wsi.level_downsamples)
+    mask_level_idx = None
+
+    for level_idx, level_mpp in reversed(list(enumerate(level_mpps))):
+        mpp_ratio = target_mpp / level_mpp
+        scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
+
+        if scaled_width * scaled_height >= min_mask_patch_pixels:
+            mask_level_idx = level_idx
+            break
+
+    if mask_level_idx is None:
+        raise ValueError("No level with the specified minimum number of patch pixels available.")
+
+    return mask_level_idx
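To make the level search in `get_mask_level` concrete, a hedged walk-through with hypothetical slide metadata: `wsi.mpp = 0.25` and downsamples `[1, 4, 16, 64]` give level mpps `[0.25, 1, 4, 16]`, and the lowest-resolution level whose scaled patch still covers at least 3 x 3 mask pixels wins.

import numpy as np

level_mpps = 0.25 * np.array([1.0, 4.0, 16.0, 64.0])  # hypothetical slide
target_mpp, width, height = 0.5, 224, 224
for level_idx, level_mpp in reversed(list(enumerate(level_mpps))):
    mpp_ratio = target_mpp / level_mpp
    scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)
    if scaled_width * scaled_height >= 3 * 3:
        print(level_idx, scaled_width)  # 3, 7 -> a 224px patch maps to 7x7 mask pixels
        break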
eva/vision/data/wsi/patching/samplers/__init__.py
@@ -0,0 +1,14 @@
+"""Patch Sampler API."""
+
+from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
+from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
+from eva.vision.data.wsi.patching.samplers.grid import GridSampler
+from eva.vision.data.wsi.patching.samplers.random import RandomSampler
+
+__all__ = [
+    "ForegroundSampler",
+    "Sampler",
+    "ForegroundGridSampler",
+    "GridSampler",
+    "RandomSampler",
+]
eva/vision/data/wsi/patching/samplers/_utils.py
@@ -0,0 +1,50 @@
+import random
+from typing import Tuple
+
+import numpy as np
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+
+
+def get_grid_coords_and_indices(
+    layer_shape: Tuple[int, int],
+    width: int,
+    height: int,
+    overlap: Tuple[int, int],
+    shuffle: bool = True,
+    seed: int = 42,
+):
+    """Get grid coordinates and indices.
+
+    Args:
+        layer_shape: The shape of the layer.
+        width: The width of the patches.
+        height: The height of the patches.
+        overlap: The overlap between patches in the grid.
+        shuffle: Whether to shuffle the indices.
+        seed: The random seed.
+    """
+    x_range = range(0, layer_shape[0] - width + 1, width - overlap[0])
+    y_range = range(0, layer_shape[1] - height + 1, height - overlap[1])
+    x_y = [(x, y) for x in x_range for y in y_range]
+
+    indices = list(range(len(x_y)))
+    if shuffle:
+        set_seed(seed)
+        np.random.shuffle(indices)
+    return x_y, indices
+
+
+def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
+    """Checks if the width / height is bigger than the layer shape.
+
+    Args:
+        width: The width of the patches.
+        height: The height of the patches.
+        layer_shape: The shape of the layer.
+    """
+    if width > layer_shape[0] or height > layer_shape[1]:
+        raise ValueError("The width / height cannot be bigger than the layer shape.")
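A quick check of the grid arithmetic above: a 512x512 layer tiled with 224x224 patches and no overlap yields a 2x2 grid of origins.

from eva.vision.data.wsi.patching.samplers._utils import get_grid_coords_and_indices

x_y, indices = get_grid_coords_and_indices(
    layer_shape=(512, 512), width=224, height=224, overlap=(0, 0), shuffle=False
)
print(x_y)  # [(0, 0), (0, 224), (224, 0), (224, 224)]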
eva/vision/data/wsi/patching/samplers/base.py
@@ -0,0 +1,48 @@
+"""Base classes for samplers."""
+
+import abc
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+
+
+class Sampler(abc.ABC):
+    """Base class for samplers."""
+
+    @abc.abstractmethod
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+        mask: Mask | None = None,
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample patch coordinates.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+            mask: Tuple containing the mask array and the scaling factor with respect to the
+                provided layer_shape. Optional, only required for samplers with foreground
+                filtering.
+
+        Returns:
+            A generator producing sampled patch coordinates.
+        """
+
+
+class ForegroundSampler(Sampler):
+    """Base class for samplers with foreground filtering capabilities."""
+
+    @abc.abstractmethod
+    def is_foreground(
+        self,
+        mask: Mask,
+        x: int,
+        y: int,
+        width: int,
+        height: int,
+        min_foreground_ratio: float,
+    ) -> bool:
+        """Check if a patch contains sufficient foreground."""
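A minimal concrete Sampler, as a sketch of the contract above (a hypothetical class, not part of the package): it yields the single patch origin centered in the layer.

from typing import Generator, Tuple

from eva.vision.data.wsi.patching.mask import Mask
from eva.vision.data.wsi.patching.samplers import base


class CenterSampler(base.Sampler):
    """Yields one patch origin at the center of the layer."""

    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask: Mask | None = None,
    ) -> Generator[Tuple[int, int], None, None]:
        yield (layer_shape[0] - width) // 2, (layer_shape[1] - height) // 2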
eva/vision/data/wsi/patching/samplers/foreground_grid.py
@@ -0,0 +1,99 @@
+"""Foreground grid sampler."""
+
+from typing import Tuple
+
+from eva.vision.data.wsi.patching.mask import Mask
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class ForegroundGridSampler(base.ForegroundSampler):
+    """Sample patches based on a grid, only returning patches containing foreground."""
+
+    def __init__(
+        self,
+        max_samples: int = 20,
+        overlap: Tuple[int, int] = (0, 0),
+        min_foreground_ratio: float = 0.35,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the sampler.
+
+        Args:
+            max_samples: The maximum number of samples to return.
+            overlap: The overlap between patches in the grid.
+            min_foreground_ratio: The minimum amount of foreground
+                within a sampled patch.
+            seed: The random seed.
+        """
+        self.max_samples = max_samples
+        self.overlap = overlap
+        self.min_foreground_ratio = min_foreground_ratio
+        self.seed = seed
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+        mask: Mask,
+    ):
+        """Sample patches from a grid containing foreground.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+            mask: The mask of the image.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        x_y, indices = _utils.get_grid_coords_and_indices(
+            layer_shape, width, height, self.overlap, seed=self.seed
+        )
+
+        count = 0
+        for i in indices:
+            if count >= self.max_samples:
+                break
+
+            if self.is_foreground(
+                mask=mask,
+                x=x_y[i][0],
+                y=x_y[i][1],
+                width=width,
+                height=height,
+                min_foreground_ratio=self.min_foreground_ratio,
+            ):
+                count += 1
+                yield x_y[i]
+
+    def is_foreground(
+        self,
+        mask: Mask,
+        x: int,
+        y: int,
+        width: int,
+        height: int,
+        min_foreground_ratio: float,
+    ) -> bool:
+        """Check if a patch contains sufficient foreground.
+
+        Args:
+            mask: The mask of the image.
+            x: The x-coordinate of the patch.
+            y: The y-coordinate of the patch.
+            width: The width of the patch.
+            height: The height of the patch.
+            min_foreground_ratio: The minimum amount of foreground in the patch.
+        """
+        x_, y_ = self._scale_coords(x, y, mask.scale_factors)
+        width_, height_ = self._scale_coords(width, height, mask.scale_factors)
+        patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
+        return patch_mask.sum() / patch_mask.size >= min_foreground_ratio
+
+    def _scale_coords(
+        self,
+        x: int,
+        y: int,
+        scale_factors: Tuple[float, float],
+    ) -> Tuple[int, int]:
+        return int(x / scale_factors[0]), int(y / scale_factors[1])
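The coordinate scaling in `is_foreground` is worth tracing once. A hedged sketch of the arithmetic with hypothetical numbers: `scale_factors = (32.0, 32.0)`, a level-0 origin of (6400, 3200), and a 512px patch.

scale_factors = (32.0, 32.0)  # hypothetical mask-to-level-0 factors
x_, y_ = int(6400 / scale_factors[0]), int(3200 / scale_factors[1])  # (200, 100)
width_ = height_ = int(512 / scale_factors[0])  # 16
# patch_mask = mask_array[100:116, 200:216]; for a 0/1 mask the foreground
# ratio patch_mask.sum() / patch_mask.size is simply its mean.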
eva/vision/data/wsi/patching/samplers/grid.py
@@ -0,0 +1,47 @@
+"""Grid sampler."""
+
+from typing import Generator, Tuple
+
+from eva.vision.data.wsi.patching.samplers import _utils, base
+
+
+class GridSampler(base.Sampler):
+    """Sample patches based on a grid.
+
+    Args:
+        max_samples: The maximum number of samples to return.
+        overlap: The overlap between patches in the grid.
+        seed: The random seed.
+    """
+
+    def __init__(
+        self,
+        max_samples: int | None = None,
+        overlap: Tuple[int, int] = (0, 0),
+        seed: int = 42,
+    ):
+        """Initializes the sampler."""
+        self.max_samples = max_samples
+        self.overlap = overlap
+        self.seed = seed
+
+    def sample(
+        self,
+        width: int,
+        height: int,
+        layer_shape: Tuple[int, int],
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Sample patches from a grid.
+
+        Args:
+            width: The width of the patches.
+            height: The height of the patches.
+            layer_shape: The shape of the layer.
+        """
+        _utils.validate_dimensions(width, height, layer_shape)
+        x_y, indices = _utils.get_grid_coords_and_indices(
+            layer_shape, width, height, self.overlap, seed=self.seed
+        )
+        max_samples = len(indices) if self.max_samples is None else self.max_samples
+        for i in indices[:max_samples]:
+            yield x_y[i]
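A usage sketch of the class above: a deterministic shuffled grid over a hypothetical 1024x1024 layer, capped at four patches.

from eva.vision.data.wsi.patching.samplers import GridSampler

sampler = GridSampler(max_samples=4, seed=42)
for x, y in sampler.sample(width=224, height=224, layer_shape=(1024, 1024)):
    print(x, y)  # four (x, y) origins drawn from the shuffled 4x4 grid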