kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (159) hide show
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,53 @@
1
+ """Intensity level scaling transform."""
2
+
3
+ import functools
4
+ from typing import Any, Dict, Tuple
5
+
6
+ import torch
7
+ import torchvision.transforms.v2 as torch_transforms
8
+ from torchvision import tv_tensors
9
+ from typing_extensions import override
10
+
11
+ from eva.vision.data.transforms.normalization import functional
12
+
13
+
14
class RescaleIntensity(torch_transforms.Transform):
    """Stretches or shrinks the image intensity levels.

    Plain tensors and `tv_tensors.Image` inputs are rescaled with the same
    `in_range` / `out_range` configuration. Bounding boxes and masks are passed
    through unchanged, as is any input of an unregistered type.
    """

    def __init__(
        self,
        in_range: Tuple[float, float] | None = None,
        out_range: Tuple[float, float] = (0.0, 1.0),
    ) -> None:
        """Initializes the transform.

        Args:
            in_range: The input data range. If `None`, it will
                fetch the min and max of the input image.
            out_range: The desired intensity range of the output.
        """
        super().__init__()

        self._in_range = in_range
        self._out_range = out_range

    @functools.singledispatchmethod
    @override
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fallback for types without a registered overload: leave untouched.
        return inpt

    @_transform.register(torch.Tensor)
    def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
        return functional.rescale_intensity(
            inpt, in_range=self._in_range, out_range=self._out_range
        )

    @_transform.register(tv_tensors.Image)
    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
        # Images must honour the configured `in_range` exactly like plain
        # tensors do; otherwise the user-supplied input range is silently ignored.
        scaled_inpt = functional.rescale_intensity(
            inpt, in_range=self._in_range, out_range=self._out_range
        )
        # Re-wrap so downstream transforms still see a `tv_tensors.Image`.
        return tv_tensors.wrap(scaled_inpt, like=inpt)

    @_transform.register(tv_tensors.BoundingBoxes)
    @_transform.register(tv_tensors.Mask)
    def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
        # Geometric / label targets must not have their values rescaled.
        return inpt
@@ -0,0 +1,16 @@
1
+ """WSI API."""
2
+
3
+ from eva.vision.data.wsi.backends import Wsi, get_cached_wsi, wsi_backend
4
+ from eva.vision.data.wsi.patching.coordinates import PatchCoordinates, get_cached_coords
5
+ from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
6
+
7
# Public names re-exported by the WSI package (order kept stable).
__all__ = [
    # Reader base class and data containers.
    "Wsi",
    "PatchCoordinates",
    "Mask",
    # Helper functions.
    "get_cached_coords",
    "wsi_backend",
    "get_cached_wsi",
    "get_mask",
    "get_mask_level",
]
+ ]
@@ -0,0 +1,69 @@
1
+ """WSI Backends API."""
2
+
3
+ import functools
4
+ import importlib.util
5
+ from typing import Callable
6
+
7
+ from eva.vision.data.wsi.backends.base import Wsi
8
+
9
# Maximum number of slide readers memoised by `get_cached_wsi`.
LRU_CACHE_SIZE: int = 32
10
+
11
+
12
def _is_openslide_available() -> bool:
    """Check whether the optional `openslide` package can be located.

    Uses `importlib.util.find_spec`, so the module is looked up but not imported.
    """
    spec = importlib.util.find_spec("openslide")
    return spec is not None
15
+
16
+
17
def _is_tiffslide_available() -> bool:
    """Check whether the optional `tiffslide` package can be located.

    Uses `importlib.util.find_spec`, so the module is looked up but not imported.
    """
    spec = importlib.util.find_spec("tiffslide")
    return spec is not None
20
+
21
+
22
+ def is_backend_available(backend: str) -> bool:
23
+ """Whether the specified backend is available."""
24
+ match backend:
25
+ case "openslide":
26
+ return _is_openslide_available()
27
+ case "tiffslide":
28
+ return _is_tiffslide_available()
29
+ return False
30
+
31
+
32
def wsi_backend(backend: str = "openslide") -> Callable[..., Wsi]:
    """Resolve a backend name to the `Wsi` reader class implementing it.

    Backend modules are imported lazily inside this function, so an optional
    third-party dependency is only needed once its backend is requested.

    Args:
        backend: "openslide", "tiffslide" or "pil".

    Returns:
        The reader class; calling it with `(file_path, overwrite_mpp)`
        opens a slide.

    Raises:
        ValueError: If the backend is unknown or its optional dependency
            is not installed.
    """
    if backend == "openslide":
        if not _is_openslide_available():
            raise ValueError(
                "Missing optional dependency: openslide.\n"
                "Please install using `pip install openslide-python`."
            )
        from eva.vision.data.wsi.backends.openslide import WsiOpenslide

        return WsiOpenslide

    if backend == "tiffslide":
        if not _is_tiffslide_available():
            raise ValueError(
                "Missing optional dependency: tiffslide.\n"
                "Please install using `pip install tiffslide`."
            )
        from eva.vision.data.wsi.backends.tiffslide import WsiTiffslide

        return WsiTiffslide

    if backend == "pil":
        from eva.vision.data.wsi.backends.pil import PILImage

        return PILImage

    raise ValueError(f"Unknown WSI backend selected: {backend}")
61
+
62
+
63
@functools.lru_cache(LRU_CACHE_SIZE)
def get_cached_wsi(file_path: str, backend: str, overwrite_mpp: float | None = None) -> Wsi:
    """Open a whole-slide image through `backend`, memoising the reader.

    Up to `LRU_CACHE_SIZE` readers are kept by `functools.lru_cache`. They are
    keyed on the call arguments, so repeated calls with the same arguments
    reuse the already-opened slide object.
    """
    reader_cls = wsi_backend(backend)
    return reader_cls(file_path, overwrite_mpp)
67
+
68
+
69
# Public backend API. `is_backend_available` is included so that star-imports
# expose it alongside `wsi_backend`.
__all__ = [
    "Wsi",
    "wsi_backend",
    "get_cached_wsi",
    "_is_openslide_available",
    "is_backend_available",
]
@@ -0,0 +1,115 @@
1
+ """Base Module for loading data from WSI files."""
2
+
3
+ import abc
4
+ from typing import Any, Sequence, Tuple
5
+
6
+ import numpy as np
7
+
8
+
9
+ class Wsi(abc.ABC):
10
+ """Base class for loading data from Whole Slide Image (WSI) files."""
11
+
12
+ def __init__(self, file_path: str, overwrite_mpp: float | None = None):
13
+ """Initializes a Wsi object.
14
+
15
+ Args:
16
+ file_path: The path to the WSI file.
17
+ overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
18
+ """
19
+ self._wsi = self.open_file(file_path)
20
+ self._overwrite_mpp = overwrite_mpp
21
+
22
+ @abc.abstractmethod
23
+ def open_file(self, file_path: str) -> Any:
24
+ """Opens the WSI file.
25
+
26
+ Args:
27
+ file_path: The path to the WSI file.
28
+ """
29
+
30
+ @property
31
+ @abc.abstractmethod
32
+ def level_dimensions(self) -> Sequence[Tuple[int, int]]:
33
+ """A list of (width, height) tuples for each level, from highest to lowest resolution."""
34
+
35
+ @property
36
+ @abc.abstractmethod
37
+ def level_downsamples(self) -> Sequence[float]:
38
+ """A list of downsampling factors for each level, relative to the highest resolution."""
39
+
40
+ @property
41
+ @abc.abstractmethod
42
+ def mpp(self) -> float:
43
+ """Microns per pixel at the highest resolution (level 0)."""
44
+
45
+ @abc.abstractmethod
46
+ def _read_region(
47
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
48
+ ) -> np.ndarray:
49
+ """Abstract method to read a region at a specified zoom level."""
50
+
51
+ def read_region(
52
+ self, location: Tuple[int, int], level: int, size: Tuple[int, int]
53
+ ) -> np.ndarray:
54
+ """Reads and returns image data for a specified region and zoom level.
55
+
56
+ Args:
57
+ location: Top-left corner (x, y) to start reading at level 0.
58
+ level: WSI level to read from.
59
+ size: Region size as (width, height) in pixels at the selected read level.
60
+ Remember to scale the size correctly.
61
+ """
62
+ self._verify_location(location, size)
63
+ data = self._read_region(location, level, size)
64
+ return self._read_postprocess(data)
65
+
66
+ def get_closest_level(self, target_mpp: float) -> int:
67
+ """Calculate the slide level that is closest to the target mpp.
68
+
69
+ Args:
70
+ slide: The whole-slide image object.
71
+ target_mpp: The target microns per pixel (mpp) value.
72
+ """
73
+ # Calculate the mpp for each level
74
+ level_mpps = self.mpp * np.array(self.level_downsamples)
75
+
76
+ # Ignore levels with higher mpp
77
+ level_mpps_filtered = level_mpps.copy()
78
+ level_mpps_filtered[level_mpps_filtered > target_mpp] = 0
79
+
80
+ if level_mpps_filtered.max() == 0:
81
+ # When all levels have higher mpp than target_mpp return the level with lowest mpp
82
+ level_idx = np.argmin(level_mpps)
83
+ else:
84
+ level_idx = np.argmax(level_mpps_filtered)
85
+
86
+ return int(level_idx)
87
+
88
+ def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None:
89
+ """Verifies that the requested region is within the slide dimensions.
90
+
91
+ Args:
92
+ location: Top-left corner (x, y) to start reading at level 0.
93
+ size: Region size as (width, height) in pixels at the selected read level.
94
+ """
95
+ x_max, y_max = self.level_dimensions[0]
96
+ x_scale = x_max / self.level_dimensions[0][0]
97
+ y_scale = y_max / self.level_dimensions[0][1]
98
+
99
+ if (
100
+ int(location[0] + x_scale * size[0]) > x_max
101
+ or int(location[1] + y_scale * size[1]) > y_max
102
+ ):
103
+ raise ValueError(f"Out of bounds region: {location}, {size}")
104
+
105
+ def _read_postprocess(self, data: np.ndarray) -> np.ndarray:
106
+ """Post-processes the read region data.
107
+
108
+ Args:
109
+ data: The read region data as a numpy array of shape (height, width, channels).
110
+ """
111
+ # Change color to white where the alpha channel is 0
112
+ if data.shape[2] == 4:
113
+ data[data[:, :, 3] == 0] = 255
114
+
115
+ return data[:, :, :3]
@@ -0,0 +1,73 @@
1
+ """Module for loading data from WSI files using the OpenSlide library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import openslide
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class WsiOpenslide(base.Wsi):
    """Class for loading data from WSI files using the OpenSlide library."""

    _wsi: openslide.OpenSlide

    @override
    def open_file(self, file_path: str) -> openslide.OpenSlide:
        return openslide.OpenSlide(file_path)

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return self._wsi.level_dimensions

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return self._wsi.level_downsamples

    @property
    @override
    def mpp(self) -> float:
        """Microns per pixel at level 0.

        The value is resolved from, in order of preference:
          1. the `openslide.PROPERTY_NAME_MPP_X` / `PROPERTY_NAME_MPP_Y` properties,
          2. the `tiff.XResolution` / `tiff.YResolution` / `tiff.ResolutionUnit` tags,
          3. the `overwrite_mpp` passed at construction time.

        Raises:
            ValueError: If no source above yields a value, or the TIFF resolution
                unit is unsupported and no `overwrite_mpp` was given.
        """
        properties = self._wsi.properties
        if properties.get(openslide.PROPERTY_NAME_MPP_X) and properties.get(
            openslide.PROPERTY_NAME_MPP_Y
        ):
            x_mpp = float(properties[openslide.PROPERTY_NAME_MPP_X])
            y_mpp = float(properties[openslide.PROPERTY_NAME_MPP_Y])
        elif (
            properties.get("tiff.XResolution")
            and properties.get("tiff.YResolution")
            and properties.get("tiff.ResolutionUnit")
        ):
            unit = properties["tiff.ResolutionUnit"]
            if unit not in _conversion_factor_to_micrometer:
                if self._overwrite_mpp is not None:
                    return self._overwrite_mpp
                raise ValueError(f"Unit {unit} not supported.")

            # The resolution tags are pixels per `unit`, so microns per pixel is
            # (microns per unit) / (pixels per unit).
            conversion_factor = float(_conversion_factor_to_micrometer[unit])
            x_mpp = conversion_factor / float(properties["tiff.XResolution"])
            y_mpp = conversion_factor / float(properties["tiff.YResolution"])
        elif self._overwrite_mpp is not None:
            # No usable metadata: fall back to the user-provided value.
            return self._overwrite_mpp
        else:
            raise ValueError("`mpp` cannot be obtained for this slide.")

        return (x_mpp + y_mpp) / 2.0

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        return np.array(self._wsi.read_region(location, level, size))
62
+
63
+
64
# Micrometers per unit for the values OpenSlide reports in `tiff.ResolutionUnit`.
# "inch" is the TIFF default resolution unit, so it must be supported too.
_conversion_factor_to_micrometer = {
    "meter": 10**6,
    "decimeter": 10**5,
    "centimeter": 10**4,
    "millimeter": 10**3,
    "micrometer": 1,
    "nanometer": 10**-3,
    "picometer": 10**-6,
    "femtometer": 10**-9,
    "inch": 25400,
}
@@ -0,0 +1,52 @@
1
+ """Module for loading data from standard image file formats using the PIL library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class PILImage(base.Wsi):
    """Class for loading data from standard image file formats using PIL library.

    The image is treated as a single-level slide; `mpp` must be supplied via
    `overwrite_mpp`.
    """

    _wsi: PIL.Image.Image

    @override
    def open_file(self, file_path: str) -> PIL.Image.Image:
        # `convert` materialises the pixels into a new in-memory image, so the
        # lazily opened source file can be closed immediately instead of
        # lingering until garbage collection.
        with PIL.Image.open(file_path) as image:
            return image.convert("RGB")

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return [self._wsi.size]

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return [1.0]

    @property
    @override
    def mpp(self) -> float:
        if self._overwrite_mpp is None:
            raise ValueError("Please specify the mpp using the `overwrite_mpp` argument.")
        return self._overwrite_mpp

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        # Single-level image: `level` is not used.
        width, height = size[0], size[1]
        patch = self._wsi.crop(
            # (left, upper, right, lower)
            (
                location[0],
                location[1],
                location[0] + width,
                location[1] + height,
            )
        )
        return np.array(patch)
@@ -0,0 +1,42 @@
1
+ """Module for loading data from WSI files using the TiffSlide library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import tiffslide # type: ignore
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class WsiTiffslide(base.Wsi):
    """Class for loading data from WSI files using the TiffSlide library."""

    _wsi: tiffslide.TiffSlide

    @override
    def open_file(self, file_path: str) -> tiffslide.TiffSlide:
        return tiffslide.TiffSlide(file_path)

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return self._wsi.level_dimensions

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return self._wsi.level_downsamples

    @property
    @override
    def mpp(self) -> float:
        """Microns per pixel at level 0.

        Taken as the mean of the `PROPERTY_NAME_MPP_X` / `PROPERTY_NAME_MPP_Y`
        slide properties. When either is missing or empty, the `overwrite_mpp`
        given at construction time is used instead.

        Raises:
            ValueError: If the metadata is missing and no `overwrite_mpp` was set.
        """
        properties = self._wsi.properties
        x_mpp = properties.get(tiffslide.PROPERTY_NAME_MPP_X)
        y_mpp = properties.get(tiffslide.PROPERTY_NAME_MPP_Y)
        if x_mpp and y_mpp:
            return (float(x_mpp) + float(y_mpp)) / 2.0
        if self._overwrite_mpp is not None:
            return self._overwrite_mpp
        raise ValueError("`mpp` cannot be obtained for this slide.")

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        return np.array(self._wsi.read_region(location, level, size))
@@ -0,0 +1,6 @@
1
+ """WSI Patching API."""
2
+
3
+ from eva.vision.data.wsi.patching import samplers
4
+ from eva.vision.data.wsi.patching.coordinates import PatchCoordinates
5
+
6
# Public names of the patching package: the sampler sub-module and the
# coordinate container.
__all__ = [
    "samplers",
    "PatchCoordinates",
]
@@ -0,0 +1,98 @@
1
+ """A module for handling coordinates of patches from a whole-slide image."""
2
+
3
+ import dataclasses
4
+ import functools
5
+ from typing import List, Tuple
6
+
7
+ from eva.vision.data.wsi import backends
8
+ from eva.vision.data.wsi.patching import samplers
9
+ from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
10
+
11
# Maximum number of `PatchCoordinates` results memoised by `get_cached_coords`.
LRU_CACHE_SIZE: int = 32
12
+
13
+
14
@dataclasses.dataclass
class PatchCoordinates:
    """Patch locations sampled from a whole-slide image.

    Args:
        x_y: A list of (x, y) coordinates of the patches (refer to level 0).
        width: The width of the patches, in pixels (refers to level_idx).
        height: The height of the patches, in pixels (refers to level_idx).
        level_idx: The level index at which to extract the patches.
        mask: The foreground mask of the wsi.
    """

    x_y: List[Tuple[int, int]]
    width: int
    height: int
    level_idx: int
    mask: Mask | None = None

    @classmethod
    def from_file(
        cls,
        wsi_path: str,
        width: int,
        height: int,
        sampler: samplers.Sampler,
        target_mpp: float,
        overwrite_mpp: float | None = None,
        backend: str = "openslide",
    ) -> "PatchCoordinates":
        """Sample patch coordinates from a whole-slide image file.

        Coordinates are sampled in level-0 pixels. The patch size is then
        expressed at the level whose mpp is closest to `target_mpp`, which is
        the level the patches will later be read from.

        Args:
            wsi_path: The path to the whole-slide image file.
            width: The width of the patches to be extracted, in pixels.
            height: The height of the patches to be extracted, in pixels.
            sampler: The sampler to use for sampling patch coordinates.
            target_mpp: The target microns per pixel (mpp) for the patches.
            overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
            backend: The backend to use for reading the whole-slide images.
        """
        slide = backends.wsi_backend(backend)(wsi_path, overwrite_mpp)

        # Patch size converted to level-0 pixels for the sampler.
        level0_ratio = target_mpp / slide.mpp
        sampler_kwargs = dict(
            width=int(level0_ratio * width),
            height=int(level0_ratio * height),
            layer_shape=slide.level_dimensions[0],
        )

        # Only foreground samplers receive (and need) a tissue mask.
        foreground_mask = None
        if isinstance(sampler, samplers.ForegroundSampler):
            mask_level = get_mask_level(slide, width, height, target_mpp)
            foreground_mask = get_mask(slide, mask_level)
            sampler_kwargs["mask"] = foreground_mask

        coords = list(sampler.sample(**sampler_kwargs))

        # Patch size expressed at the read level closest to target_mpp.
        read_level = slide.get_closest_level(target_mpp)
        level_ratio = target_mpp / (slide.mpp * slide.level_downsamples[read_level])

        return cls(
            coords,
            int(level_ratio * width),
            int(level_ratio * height),
            read_level,
            foreground_mask,
        )
77
+
78
+
79
@functools.lru_cache(LRU_CACHE_SIZE)
def get_cached_coords(
    file_path: str,
    width: int,
    height: int,
    target_mpp: float,
    overwrite_mpp: float | None,
    sampler: samplers.Sampler,
    backend: str,
) -> PatchCoordinates:
    """Memoised wrapper around `PatchCoordinates.from_file`.

    Up to `LRU_CACHE_SIZE` results are kept by `functools.lru_cache`. They are
    keyed on the call arguments, so every argument (including `sampler`) must
    be hashable.
    """
    from_file_kwargs = {
        "wsi_path": file_path,
        "width": width,
        "height": height,
        "target_mpp": target_mpp,
        "overwrite_mpp": overwrite_mpp,
        "backend": backend,
        "sampler": sampler,
    }
    return PatchCoordinates.from_file(**from_file_kwargs)
@@ -0,0 +1,123 @@
1
+ """Functions for extracting foreground masks."""
2
+
3
+ import dataclasses
4
+ from typing import Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ from eva.vision.data.wsi.backends.base import Wsi
10
+
11
+
12
@dataclasses.dataclass
class Mask:
    """Foreground mask of a whole-slide image plus the level it was computed at.

    Attributes:
        mask_array: Binary mask array; 1 marks foreground and 0 background.
        mask_level_idx: WSI level index at which `mask_array` was extracted.
        scale_factors: (x, y) factors that map coordinates at `mask_level_idx`
            to level 0.
    """

    # Binary array: 1 = foreground, 0 = background.
    mask_array: np.ndarray
    # WSI level the mask was computed on.
    mask_level_idx: int
    # Multipliers converting mask-level (x, y) coordinates to level 0.
    scale_factors: Tuple[float, float]
24
+
25
+
26
def get_mask(
    wsi: Wsi,
    mask_level_idx: int,
    saturation_threshold: int = 20,
    median_blur_kernel_size: int | None = None,
    fill_holes: bool = False,
    holes_kernel_size: Tuple[int, int] = (7, 7),
    use_otsu: bool = False,
) -> Mask:
    """Generates a binary foreground mask for a given WSI.

    This is a simplified version of the algorithm proposed in [1] (CLAM):
    1. Convert the image to the HSV color space (color saturation is easier to
       isolate in HSV than in RGB).
    2. (optional) Apply a median blur to the saturation channel to reduce noise
       and close small gaps in the mask. While this yields cleaner masks, this step
       is the most computationally expensive and thus disabled by default (CLAM
       uses a value of 7).
    3. Calculate a binary mask by thresholding the saturation channel.
    4. (optional) Fill holes: dilate the mask, then fill every detected contour.

    [1] Lu, Ming Y., et al. "Data-efficient and weakly supervised computational
    pathology on whole-slide images." Nature biomedical engineering 5.6 (2021): 555-570.
    https://github.com/mahmoodlab/CLAM

    Args:
        wsi: The WSI object.
        mask_level_idx: The level index of the WSI at which we want to extract the mask.
        saturation_threshold: The threshold value for the saturation channel.
        median_blur_kernel_size: Kernel size for the median blur operation. `None`
            or 0 disables the blur; otherwise it must be an odd integer >= 3.
        fill_holes: Whether to fill holes in the mask.
        holes_kernel_size: The size of the kernel for morphological operations to fill holes.
        use_otsu: Whether to use Otsu's method for the thresholding operation. If False,
            a fixed threshold value is used.

    Returns:
        A Mask object instance.

    Raises:
        ValueError: If `median_blur_kernel_size` is set but not an odd integer >= 3.
    """
    # OpenCV's median blur requires an odd aperture greater than 1. Validate up
    # front so callers get a clear error before the (potentially large) region read.
    if median_blur_kernel_size and (
        median_blur_kernel_size < 3 or median_blur_kernel_size % 2 == 0
    ):
        raise ValueError(
            "`median_blur_kernel_size` must be an odd integer >= 3, "
            f"got {median_blur_kernel_size}."
        )

    image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx])
    image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # Keep only the saturation channel (optionally denoised).
    image = (
        cv2.medianBlur(image[:, :, 1], median_blur_kernel_size)
        if median_blur_kernel_size
        else image[:, :, 1]
    )

    threshold_type = cv2.THRESH_BINARY + cv2.THRESH_OTSU if use_otsu else cv2.THRESH_BINARY
    # maxval=1 yields a {0, 1} mask rather than {0, 255}.
    _, mask_array = cv2.threshold(image, saturation_threshold, 1, threshold_type)

    if fill_holes:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, holes_kernel_size)
        mask_array = cv2.dilate(mask_array, kernel, iterations=1)
        contour, _ = cv2.findContours(mask_array, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contour:
            # thickness=-1 draws the contour filled.
            cv2.drawContours(mask_array, [cnt], 0, (1,), -1)

    mask_array = mask_array.astype(np.uint8)
    scale_factors = (
        wsi.level_dimensions[0][0] / wsi.level_dimensions[mask_level_idx][0],
        wsi.level_dimensions[0][1] / wsi.level_dimensions[mask_level_idx][1],
    )

    return Mask(mask_array=mask_array, mask_level_idx=mask_level_idx, scale_factors=scale_factors)
85
+
86
+
87
def get_mask_level(
    wsi: Wsi,
    width: int,
    height: int,
    target_mpp: float,
    min_mask_patch_pixels: int = 3 * 3,
) -> int:
    """Pick the lowest-resolution WSI level that still gives usable mask patches.

    Computing the mask at a coarse level is cheap. However, at very coarse
    levels a `width` x `height` patch (defined at `target_mpp`) shrinks to just
    a few pixels or collapses entirely. Levels are therefore scanned from the
    last index (lowest resolution) towards level 0. The first level at which
    the scaled patch covers at least `min_mask_patch_pixels` pixels is returned.

    Args:
        wsi: The WSI object.
        width: The width of the patches to be extracted, in pixels (at target_mpp).
        height: The height of the patches to be extracted, in pixels (at target_mpp).
        target_mpp: The target microns per pixel (mpp) for the patches.
        min_mask_patch_pixels: The minimum number of pixels required for the mask patches.
            Mask patch refers to width / height at target_mpp scaled down to the WSI level
            at which the mask is generated.

    Returns:
        The selected level index.

    Raises:
        ValueError: If no level reaches `min_mask_patch_pixels`.
    """
    level_mpps = wsi.mpp * np.array(wsi.level_downsamples)

    for candidate in range(len(level_mpps) - 1, -1, -1):
        scale = target_mpp / level_mpps[candidate]
        patch_pixels = int(scale * width) * int(scale * height)
        if patch_pixels >= min_mask_patch_pixels:
            return candidate

    raise ValueError("No level with the specified minimum number of patch pixels available.")
@@ -0,0 +1,14 @@
1
+ """Patch Sampler API."""
2
+
3
+ from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
4
+ from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
5
+ from eva.vision.data.wsi.patching.samplers.grid import GridSampler
6
+ from eva.vision.data.wsi.patching.samplers.random import RandomSampler
7
+
8
# Public sampler API (order kept stable).
__all__ = [
    # Base classes from `samplers.base`.
    "ForegroundSampler",
    "Sampler",
    # Concrete sampler implementations.
    "ForegroundGridSampler",
    "GridSampler",
    "RandomSampler",
]