kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,73 @@
1
+ """Module for loading data from WSI files using the OpenSlide library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import openslide
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class WsiOpenslide(base.Wsi):
    """Class for loading data from WSI files using the OpenSlide library."""

    _wsi: openslide.OpenSlide

    @override
    def open_file(self, file_path: str) -> openslide.OpenSlide:
        return openslide.OpenSlide(file_path)

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return self._wsi.level_dimensions

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return self._wsi.level_downsamples

    @property
    @override
    def mpp(self) -> float:
        """Microns per pixel of the slide, averaged over the x and y axes.

        The value is looked up in this order:
        1. OpenSlide's standard `PROPERTY_NAME_MPP_X` / `PROPERTY_NAME_MPP_Y` properties.
        2. The `tiff.XResolution` / `tiff.YResolution` / `tiff.ResolutionUnit` properties.
        3. The `overwrite_mpp` value the backend was constructed with, if any.

        Raises:
            ValueError: If the TIFF resolution unit is not supported, or if no mpp can be
                obtained from the metadata and no `overwrite_mpp` was provided.
        """
        properties = self._wsi.properties

        if properties.get(openslide.PROPERTY_NAME_MPP_X) and properties.get(
            openslide.PROPERTY_NAME_MPP_Y
        ):
            x_mpp = float(properties[openslide.PROPERTY_NAME_MPP_X])
            y_mpp = float(properties[openslide.PROPERTY_NAME_MPP_Y])
        elif (
            properties.get("tiff.XResolution")
            and properties.get("tiff.YResolution")
            and properties.get("tiff.ResolutionUnit")
        ):
            unit = properties["tiff.ResolutionUnit"]
            if unit not in _conversion_factor_to_micrometer:
                raise ValueError(f"Unit {unit} not supported.")

            # TIFF resolutions are pixels per unit, so (micrometers per unit) divided by
            # (pixels per unit) yields micrometers per pixel.
            conversion_factor = float(_conversion_factor_to_micrometer[unit])
            x_mpp = conversion_factor / float(properties["tiff.XResolution"])
            y_mpp = conversion_factor / float(properties["tiff.YResolution"])
        elif self._overwrite_mpp is not None:
            # No resolution metadata available: fall back to the user-provided value.
            return self._overwrite_mpp
        else:
            raise ValueError("`mpp` cannot be obtained for this slide.")

        return (x_mpp + y_mpp) / 2.0

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        return np.array(self._wsi.read_region(location, level, size))
62
+
63
+
64
# Number of micrometers in one unit of length, keyed by the unit names that can appear
# in the `tiff.ResolutionUnit` slide property. TIFF's default resolution unit is the
# inch, so it is included alongside the metric units (1 inch = 25.4 mm = 25400 um).
_conversion_factor_to_micrometer = {
    "meter": 10**6,
    "decimeter": 10**5,
    "centimeter": 10**4,
    "millimeter": 10**3,
    "micrometer": 1,
    "nanometer": 10**-3,
    "picometer": 10**-6,
    "femtometer": 10**-9,
    "inch": 25400,
}
@@ -0,0 +1,52 @@
1
+ """Module for loading data from standard image file formats using the PIL library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class PILImage(base.Wsi):
    """Class for loading data from standard image file formats using PIL library.

    The image is exposed as a single-level "slide": level 0 is the full image and
    no other levels exist.
    """

    _wsi: PIL.Image.Image

    @override
    def open_file(self, file_path: str) -> PIL.Image.Image:
        # `convert` returns a new, fully loaded image, so the file handle opened by
        # `PIL.Image.open` can be released immediately. Without the context manager,
        # multi-frame formats (e.g. TIFF) keep the underlying file open.
        with PIL.Image.open(file_path) as image:
            return image.convert("RGB")

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return [self._wsi.size]

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return [1.0]

    @property
    @override
    def mpp(self) -> float:
        if self._overwrite_mpp is None:
            raise ValueError("Please specify the mpp using the `overwrite_mpp` argument.")
        return self._overwrite_mpp

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        """Crop a `size` region whose top-left corner is at `location`.

        Raises:
            ValueError: If `level` is not 0; plain images only have a single level.
        """
        if level != 0:
            raise ValueError(f"PIL images only have a single level (0), got level {level}.")

        width, height = size[0], size[1]
        patch = self._wsi.crop(
            # (left, upper, right, lower)
            (
                location[0],
                location[1],
                location[0] + width,
                location[1] + height,
            )
        )
        return np.array(patch)
@@ -0,0 +1,42 @@
1
+ """Module for loading data from WSI files using the TiffSlide library."""
2
+
3
+ from typing import Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import tiffslide # type: ignore
7
+ from typing_extensions import override
8
+
9
+ from eva.vision.data.wsi.backends import base
10
+
11
+
12
class WsiTiffslide(base.Wsi):
    """Class for loading data from WSI files using the TiffSlide library."""

    _wsi: tiffslide.TiffSlide

    @override
    def open_file(self, file_path: str) -> tiffslide.TiffSlide:
        return tiffslide.TiffSlide(file_path)

    @property
    @override
    def level_dimensions(self) -> Sequence[Tuple[int, int]]:
        return self._wsi.level_dimensions

    @property
    @override
    def level_downsamples(self) -> Sequence[float]:
        return self._wsi.level_downsamples

    @property
    @override
    def mpp(self) -> float:
        """Microns per pixel of the slide, averaged over the x and y axes.

        Falls back to the `overwrite_mpp` value the backend was constructed with when
        the slide metadata does not provide the x/y mpp properties.

        Raises:
            ValueError: If no mpp is available in the metadata and no `overwrite_mpp`
                was provided.
        """
        properties = self._wsi.properties
        x_mpp = properties.get(tiffslide.PROPERTY_NAME_MPP_X)
        y_mpp = properties.get(tiffslide.PROPERTY_NAME_MPP_Y)

        if x_mpp is not None and y_mpp is not None:
            return (float(x_mpp) + float(y_mpp)) / 2.0
        if self._overwrite_mpp is not None:
            return self._overwrite_mpp
        raise ValueError("`mpp` cannot be obtained for this slide.")

    @override
    def _read_region(
        self, location: Tuple[int, int], level: int, size: Tuple[int, int]
    ) -> np.ndarray:
        return np.array(self._wsi.read_region(location, level, size))
@@ -0,0 +1,6 @@
1
+ """WSI Patching API."""
2
+
3
+ from eva.vision.data.wsi.patching import samplers
4
+ from eva.vision.data.wsi.patching.coordinates import PatchCoordinates
5
+
6
+ __all__ = ["samplers", "PatchCoordinates"]
@@ -0,0 +1,98 @@
1
+ """A module for handling coordinates of patches from a whole-slide image."""
2
+
3
+ import dataclasses
4
+ import functools
5
+ from typing import List, Tuple
6
+
7
+ from eva.vision.data.wsi import backends
8
+ from eva.vision.data.wsi.patching import samplers
9
+ from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
10
+
11
# Maximum number of PatchCoordinates instances kept by `get_cached_coords`.
LRU_CACHE_SIZE: int = 32
12
+
13
+
14
@dataclasses.dataclass
class PatchCoordinates:
    """Coordinates of patches sampled from a whole-slide image.

    Args:
        x_y: (x, y) coordinates of the patches, expressed at level 0.
        width: Patch width in pixels, expressed at `level_idx`.
        height: Patch height in pixels, expressed at `level_idx`.
        level_idx: Index of the WSI level the patches should be read from.
        mask: The foreground mask of the WSI (only set for foreground samplers).
    """

    x_y: List[Tuple[int, int]]
    width: int
    height: int
    level_idx: int
    mask: Mask | None = None

    @classmethod
    def from_file(
        cls,
        wsi_path: str,
        width: int,
        height: int,
        sampler: samplers.Sampler,
        target_mpp: float,
        overwrite_mpp: float | None = None,
        backend: str = "openslide",
    ) -> "PatchCoordinates":
        """Build PatchCoordinates by sampling patch positions from a whole-slide image file.

        Coordinates are sampled on the level-0 grid. The patch size is then rescaled to
        the WSI level whose resolution is closest to `target_mpp`, which is the level
        the patches are meant to be read from.

        Args:
            wsi_path: Path to the whole-slide image file.
            width: Width of the patches to extract, in pixels at `target_mpp`.
            height: Height of the patches to extract, in pixels at `target_mpp`.
            sampler: Sampler producing the patch coordinates.
            target_mpp: Target microns per pixel (mpp) of the patches.
            overwrite_mpp: Microns per pixel (mpp) to use when it is missing in the WSI metadata.
            backend: Name of the backend used to read the whole-slide image.

        Returns:
            A new PatchCoordinates instance.
        """
        wsi = backends.wsi_backend(backend)(wsi_path, overwrite_mpp)

        # Express the requested patch size in level-0 pixels for sampling.
        level_0_ratio = target_mpp / wsi.mpp
        sampling_kwargs = {
            "width": int(level_0_ratio * width),
            "height": int(level_0_ratio * height),
            "layer_shape": wsi.level_dimensions[0],
        }

        # Only foreground samplers accept (and need) a tissue mask.
        foreground_mask = None
        if isinstance(sampler, samplers.ForegroundSampler):
            mask_level = get_mask_level(wsi, width, height, target_mpp)
            foreground_mask = get_mask(wsi, mask_level)
            sampling_kwargs["mask"] = foreground_mask

        coordinates = list(sampler.sample(**sampling_kwargs))

        # Rescale the patch size to the level closest to the requested resolution.
        read_level = wsi.get_closest_level(target_mpp)
        read_ratio = target_mpp / (wsi.mpp * wsi.level_downsamples[read_level])

        return cls(
            x_y=coordinates,
            width=int(read_ratio * width),
            height=int(read_ratio * height),
            level_idx=read_level,
            mask=foreground_mask,
        )
77
+
78
+
79
@functools.lru_cache(maxsize=LRU_CACHE_SIZE)
def get_cached_coords(
    file_path: str,
    width: int,
    height: int,
    target_mpp: float,
    overwrite_mpp: float | None,
    sampler: samplers.Sampler,
    backend: str,
) -> PatchCoordinates:
    """Return the PatchCoordinates for the given parameters, memoized.

    All arguments form the cache key, so `sampler` must be hashable. At most
    `LRU_CACHE_SIZE` results are kept; the least recently used entry is evicted first.
    """
    # Positional order follows PatchCoordinates.from_file's signature.
    return PatchCoordinates.from_file(
        file_path, width, height, sampler, target_mpp, overwrite_mpp, backend
    )
@@ -0,0 +1,123 @@
1
+ """Functions for extracting foreground masks."""
2
+
3
+ import dataclasses
4
+ from typing import Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ from eva.vision.data.wsi.backends.base import Wsi
10
+
11
+
12
@dataclasses.dataclass
class Mask:
    """Foreground mask of a whole-slide image plus the metadata needed to use it.

    Attributes:
        mask_array: Binary mask array where 1s represent the foreground and 0s the background.
        mask_level_idx: WSI level index at which `mask_array` was extracted.
        scale_factors: Factors to scale x/y coordinates from `mask_level_idx` to level 0.
    """

    mask_array: np.ndarray
    mask_level_idx: int
    scale_factors: Tuple[float, float]
24
+
25
+
26
def get_mask(
    wsi: Wsi,
    mask_level_idx: int,
    saturation_threshold: int = 20,
    median_blur_kernel_size: int | None = None,
    fill_holes: bool = False,
    holes_kernel_size: Tuple[int, int] = (7, 7),
    use_otsu: bool = False,
) -> Mask:
    """Generates a binary foreground mask for a given WSI.

    This is a simplified version of the algorithm proposed in [1] (CLAM):
    1. Convert the image to the HSV color space (specific colors are easier to
       separate in HSV than in RGB).
    2. (optional) Apply a median blur to the saturation channel to reduce noise and
       close small gaps in the mask. While this yields cleaner masks, this step is the
       most computationally expensive and thus disabled by default (CLAM uses a value of 7).
    3. Calculate the binary mask by thresholding the saturation channel.
    4. (optional) Fill holes: dilate the mask with an elliptical kernel and fill the
       regions enclosed by the detected contours.

    [1] Lu, Ming Y., et al. "Data-efficient and weakly supervised computational
    pathology on whole-slide images." Nature biomedical engineering 5.6 (2021): 555-570.
    https://github.com/mahmoodlab/CLAM

    Args:
        wsi: The WSI object.
        mask_level_idx: The level index of the WSI at which the mask is extracted.
        saturation_threshold: The threshold value for the saturation channel.
        median_blur_kernel_size: Kernel size for the median blur operation. A falsy
            value (None or 0) disables the blur.
        fill_holes: Whether to fill holes in the mask.
        holes_kernel_size: Size of the kernel for the morphological operations used to
            fill holes.
        use_otsu: Whether to use Otsu's method for the thresholding operation. If False,
            the fixed `saturation_threshold` is used.

    Returns:
        A Mask object whose array holds 1 for foreground and 0 for background pixels.
    """
    rgb_image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx])
    hsv_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)

    saturation = hsv_image[:, :, 1]
    if median_blur_kernel_size:
        saturation = cv2.medianBlur(saturation, median_blur_kernel_size)

    threshold_flags = cv2.THRESH_BINARY
    if use_otsu:
        # With Otsu's method OpenCV computes the threshold itself.
        threshold_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    # maxval=1 so that foreground pixels become 1 and background pixels 0.
    _, binary_mask = cv2.threshold(saturation, saturation_threshold, 1, threshold_flags)

    if fill_holes:
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, holes_kernel_size)
        binary_mask = cv2.dilate(binary_mask, ellipse, iterations=1)
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        for single_contour in contours:
            # Thickness -1 fills the area enclosed by the contour.
            cv2.drawContours(binary_mask, [single_contour], 0, (1,), -1)

    binary_mask = binary_mask.astype(np.uint8)

    # Ratio between level-0 dimensions and mask-level dimensions, per axis.
    level_0_size = wsi.level_dimensions[0]
    mask_level_size = wsi.level_dimensions[mask_level_idx]
    scale_factors = (
        level_0_size[0] / mask_level_size[0],
        level_0_size[1] / mask_level_size[1],
    )

    return Mask(mask_array=binary_mask, mask_level_idx=mask_level_idx, scale_factors=scale_factors)
85
+
86
+
87
def get_mask_level(
    wsi: Wsi,
    width: int,
    height: int,
    target_mpp: float,
    min_mask_patch_pixels: int = 3 * 3,
) -> int:
    """Pick the coarsest WSI level at which the foreground mask is still usable.

    Generating the mask at a low resolution is cheap. If the level is too coarse,
    however, patches scaled down to it cover too few mask pixels, or collapse to a
    single pixel. Levels are therefore scanned from the coarsest (highest index) to
    level 0. The first level is returned at which a patch covers at least
    `min_mask_patch_pixels` mask pixels.

    Args:
        wsi: The WSI object.
        width: The width of the patches to be extracted, in pixels (at target_mpp).
        height: The height of the patches to be extracted, in pixels (at target_mpp).
        target_mpp: The target microns per pixel (mpp) for the patches.
        min_mask_patch_pixels: The minimum number of pixels required for the mask patches.
            Mask patch refers to width / height at target_mpp scaled down to the WSI level
            at which the mask is generated.

    Returns:
        The selected WSI level index.

    Raises:
        ValueError: If no level yields enough mask pixels per patch.
    """
    level_mpps = wsi.mpp * np.array(wsi.level_downsamples)

    for candidate in range(len(level_mpps) - 1, -1, -1):
        ratio = target_mpp / level_mpps[candidate]
        patch_w, patch_h = int(ratio * width), int(ratio * height)
        if patch_w * patch_h >= min_mask_patch_pixels:
            return candidate

    raise ValueError("No level with the specified minimum number of patch pixels available.")
@@ -0,0 +1,14 @@
1
+ """Patch Sampler API."""
2
+
3
+ from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
4
+ from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
5
+ from eva.vision.data.wsi.patching.samplers.grid import GridSampler
6
+ from eva.vision.data.wsi.patching.samplers.random import RandomSampler
7
+
8
+ __all__ = [
9
+ "ForegroundSampler",
10
+ "Sampler",
11
+ "ForegroundGridSampler",
12
+ "GridSampler",
13
+ "RandomSampler",
14
+ ]
@@ -0,0 +1,50 @@
1
+ import random
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+
6
+
7
def set_seed(seed: int) -> None:
    """Seed both Python's global `random` module and NumPy's global RNG with `seed`."""
    for seed_global_rng in (random.seed, np.random.seed):
        seed_global_rng(seed)
10
+
11
+
12
def get_grid_coords_and_indices(
    layer_shape: Tuple[int, int],
    width: int,
    height: int,
    overlap: Tuple[int, int],
    shuffle: bool = True,
    seed: int = 42,
) -> tuple[list[tuple[int, int]], list[int]]:
    """Get grid coordinates and indices.

    The grid covers `layer_shape` with patches of `width` x `height`, stepping by the
    patch size minus the overlap. Only patches that fit entirely inside the layer are
    included.

    Args:
        layer_shape: The shape of the layer, as (width, height).
        width: The width of the patches.
        height: The height of the patches.
        overlap: The overlap between patches in the grid, as (x_overlap, y_overlap).
        shuffle: Whether to shuffle the indices.
        seed: The random seed used for shuffling.

    Returns:
        The list of (x, y) grid coordinates and the list of indices into it, in
        visiting order (shuffled when `shuffle` is True).

    Raises:
        ValueError: If an overlap is not smaller than the corresponding patch size.
    """
    step_x = width - overlap[0]
    step_y = height - overlap[1]
    if step_x <= 0 or step_y <= 0:
        raise ValueError(
            f"The overlap {overlap} must be smaller than the patch size ({width}, {height})."
        )

    x_range = range(0, layer_shape[0] - width + 1, step_x)
    y_range = range(0, layer_shape[1] - height + 1, step_y)
    x_y = [(x, y) for x in x_range for y in y_range]

    indices = list(range(len(x_y)))
    if shuffle:
        # A dedicated legacy RandomState yields exactly the same permutation as seeding
        # and shuffling NumPy's global RNG, without clobbering global random state.
        np.random.RandomState(seed).shuffle(indices)
    return x_y, indices
39
+
40
+
41
def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
    """Ensure a patch of `width` x `height` fits inside a layer of shape `layer_shape`.

    Args:
        width: The width of the patches.
        height: The height of the patches.
        layer_shape: The shape of the layer, as (width, height).

    Raises:
        ValueError: If the patch is wider or taller than the layer.
    """
    exceeds_layer = width > layer_shape[0] or height > layer_shape[1]
    if exceeds_layer:
        raise ValueError("The width / height cannot be bigger than the layer shape.")
@@ -0,0 +1,48 @@
1
+ """Base classes for samplers."""
2
+
3
+ import abc
4
+ from typing import Generator, Tuple
5
+
6
+ from eva.vision.data.wsi.patching.mask import Mask
7
+
8
+
9
class Sampler(abc.ABC):
    """Abstract interface for patch coordinate samplers."""

    @abc.abstractmethod
    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask: Mask | None = None,
    ) -> Generator[Tuple[int, int], None, None]:
        """Sample patch coordinates.

        Args:
            width: The width of the patches.
            height: The height of the patches.
            layer_shape: The shape of the layer.
            mask: Foreground mask of the slide. It holds the mask array and the factors
                to scale coordinates from the mask level to level 0. Optional; only
                required for samplers with foreground filtering.

        Returns:
            A generator producing sampled patch coordinates.
        """
33
+
34
+
35
class ForegroundSampler(Sampler):
    """Abstract base for samplers that only keep patches containing enough foreground."""

    @abc.abstractmethod
    def is_foreground(
        self,
        mask: Mask,
        x: int,
        y: int,
        width: int,
        height: int,
        min_foreground_ratio: float,
    ) -> bool:
        """Check whether a patch contains sufficient foreground.

        Args:
            mask: The foreground mask of the slide.
            x: The x-coordinate of the patch.
            y: The y-coordinate of the patch.
            width: The width of the patch.
            height: The height of the patch.
            min_foreground_ratio: The minimum amount of foreground the patch must contain.

        Returns:
            True if the patch contains sufficient foreground, False otherwise.
        """
@@ -0,0 +1,99 @@
1
+ """Foreground grid sampler."""
2
+
3
+ from typing import Tuple
4
+
5
+ from eva.vision.data.wsi.patching.mask import Mask
6
+ from eva.vision.data.wsi.patching.samplers import _utils, base
7
+
8
+
9
class ForegroundGridSampler(base.ForegroundSampler):
    """Sample patches based on a grid, only returning patches containing foreground."""

    def __init__(
        self,
        max_samples: int = 20,
        overlap: Tuple[int, int] = (0, 0),
        min_foreground_ratio: float = 0.35,
        seed: int = 42,
    ) -> None:
        """Initializes the sampler.

        Args:
            max_samples: The maximum number of samples to return.
            overlap: The overlap between patches in the grid.
            min_foreground_ratio: The minimum amount of foreground
                within a sampled patch.
            seed: The random seed.
        """
        self.max_samples = max_samples
        self.overlap = overlap
        self.min_foreground_ratio = min_foreground_ratio
        self.seed = seed

    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask: Mask,
    ):
        """Sample patches from a grid containing foreground.

        Grid positions are visited in the shuffled (seeded) order produced by
        `_utils.get_grid_coords_and_indices`. Positions whose mask patch does not reach
        `min_foreground_ratio` are skipped. Sampling stops after `max_samples` hits.

        Args:
            width: The width of the patches.
            height: The height of the patches.
            layer_shape: The shape of the layer.
            mask: The mask of the image.

        Yields:
            (x, y) coordinates of patches with sufficient foreground.
        """
        _utils.validate_dimensions(width, height, layer_shape)
        x_y, indices = _utils.get_grid_coords_and_indices(
            layer_shape, width, height, self.overlap, seed=self.seed
        )

        count = 0
        for i in indices:
            if count >= self.max_samples:
                break

            x, y = x_y[i]
            if self.is_foreground(
                mask=mask,
                x=x,
                y=y,
                width=width,
                height=height,
                min_foreground_ratio=self.min_foreground_ratio,
            ):
                count += 1
                yield x_y[i]

    def is_foreground(
        self,
        mask: Mask,
        x: int,
        y: int,
        width: int,
        height: int,
        min_foreground_ratio: float,
    ) -> bool:
        """Check if a patch contains sufficient foreground.

        The patch is mapped onto the mask array using `mask.scale_factors`. The
        fraction of foreground (1-valued) mask pixels is then compared against
        `min_foreground_ratio`.

        Args:
            mask: The mask of the image.
            x: The x-coordinate of the patch.
            y: The y-coordinate of the patch.
            width: The width of the patch.
            height: The height of the patch.
            min_foreground_ratio: The minimum amount of foreground in the patch.

        Returns:
            True if the foreground ratio reaches `min_foreground_ratio`. An empty mask
            patch (e.g. collapsed to zero pixels) is never considered foreground.
        """
        x_, y_ = self._scale_coords(x, y, mask.scale_factors)
        width_, height_ = self._scale_coords(width, height, mask.scale_factors)
        patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
        if patch_mask.size == 0:
            # Avoid 0 / 0 (NaN plus a RuntimeWarning); there is no foreground to measure.
            return False
        return bool(patch_mask.sum() / patch_mask.size >= min_foreground_ratio)

    def _scale_coords(
        self,
        x: int,
        y: int,
        scale_factors: Tuple[float, float],
    ) -> Tuple[int, int]:
        """Map an (x, y) pair from level 0 to the mask level, truncating to ints."""
        return int(x / scale_factors[0]), int(y / scale_factors[1])
@@ -0,0 +1,47 @@
1
+ """Grid sampler."""
2
+
3
+ from typing import Generator, Tuple
4
+
5
+ from eva.vision.data.wsi.patching.samplers import _utils, base
6
+
7
+
8
class GridSampler(base.Sampler):
    """Sample patches based on a grid.

    Args:
        max_samples: The maximum number of samples to return (None returns all).
        overlap: The overlap between patches in the grid.
        seed: The random seed.
    """

    def __init__(
        self,
        max_samples: int | None = None,
        overlap: Tuple[int, int] = (0, 0),
        seed: int = 42,
    ):
        """Initializes the sampler."""
        self.max_samples = max_samples
        self.overlap = overlap
        self.seed = seed

    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
    ) -> Generator[Tuple[int, int], None, None]:
        """Sample patches from a grid.

        Args:
            width: The width of the patches.
            height: The height of the patches.
            layer_shape: The shape of the layer.

        Yields:
            (x, y) grid coordinates in shuffled (seeded) order, at most `max_samples`.
        """
        _utils.validate_dimensions(width, height, layer_shape)
        grid, order = _utils.get_grid_coords_and_indices(
            layer_shape, width, height, self.overlap, seed=self.seed
        )

        if self.max_samples is None:
            chosen = order[: len(order)]
        else:
            chosen = order[: self.max_samples]

        for position in chosen:
            yield grid[position]