kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.


Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0

eva/vision/data/datasets/wsi.py
@@ -0,0 +1,187 @@
+ """Dataset classes for whole-slide images."""
+
+ import bisect
+ import os
+ from typing import Callable, List
+
+ from loguru import logger
+ from torch.utils.data import dataset as torch_datasets
+ from torchvision import tv_tensors
+ from torchvision.transforms.v2 import functional
+ from typing_extensions import override
+
+ from eva.vision.data import wsi
+ from eva.vision.data.datasets import vision
+ from eva.vision.data.wsi.patching import samplers
+
+
+ class WsiDataset(vision.VisionDataset):
+     """Dataset class for reading patches from whole-slide images."""
+
+     def __init__(
+         self,
+         file_path: str,
+         width: int,
+         height: int,
+         sampler: samplers.Sampler,
+         target_mpp: float,
+         overwrite_mpp: float | None = None,
+         backend: str = "openslide",
+         image_transforms: Callable | None = None,
+     ):
+         """Initializes a new dataset instance.
+
+         Args:
+             file_path: Path to the whole-slide image file.
+             width: Width of the patches to be extracted, in pixels.
+             height: Height of the patches to be extracted, in pixels.
+             sampler: The sampler to use for sampling patch coordinates.
+             target_mpp: Target microns per pixel (mpp) for the patches.
+             overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+             backend: The backend to use for reading the whole-slide images.
+             image_transforms: Transforms to apply to the extracted image patches.
+         """
+         super().__init__()
+
+         self._file_path = file_path
+         self._width = width
+         self._height = height
+         self._sampler = sampler
+         self._target_mpp = target_mpp
+         self._overwrite_mpp = overwrite_mpp
+         self._backend = backend
+         self._image_transforms = image_transforms
+
+     @override
+     def __len__(self):
+         return len(self._coords.x_y)
+
+     @override
+     def filename(self, index: int) -> str:
+         return f"{self._file_path}_{index}"
+
+     @property
+     def _wsi(self) -> wsi.Wsi:
+         return wsi.get_cached_wsi(self._file_path, self._backend, self._overwrite_mpp)
+
+     @property
+     def _coords(self) -> wsi.PatchCoordinates:
+         return wsi.get_cached_coords(
+             file_path=self._file_path,
+             width=self._width,
+             height=self._height,
+             target_mpp=self._target_mpp,
+             overwrite_mpp=self._overwrite_mpp,
+             sampler=self._sampler,
+             backend=self._backend,
+         )
+
+     @override
+     def __getitem__(self, index: int) -> tv_tensors.Image:
+         x, y = self._coords.x_y[index]
+         width, height, level_idx = self._coords.width, self._coords.height, self._coords.level_idx
+         patch = self._wsi.read_region((x, y), level_idx, (width, height))
+         patch = functional.to_image(patch)
+         patch = self._apply_transforms(patch)
+         return patch
+
+     def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
+         if self._image_transforms is not None:
+             image = self._image_transforms(image)
+         return image
+
+
+ class MultiWsiDataset(vision.VisionDataset):
+     """Dataset class for reading patches from multiple whole-slide images."""
+
+     def __init__(
+         self,
+         root: str,
+         file_paths: List[str],
+         width: int,
+         height: int,
+         sampler: samplers.Sampler,
+         target_mpp: float,
+         overwrite_mpp: float | None = None,
+         backend: str = "openslide",
+         image_transforms: Callable | None = None,
+     ):
+         """Initializes a new dataset instance.
+
+         Args:
+             root: Root directory of the dataset.
+             file_paths: List of paths to the whole-slide image files, relative to the root.
+             width: Width of the patches to be extracted, in pixels.
+             height: Height of the patches to be extracted, in pixels.
+             sampler: The sampler to use for sampling patch coordinates.
+             target_mpp: Target microns per pixel (mpp) for the patches.
+             overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+             backend: The backend to use for reading the whole-slide images.
+             image_transforms: Transforms to apply to the extracted image patches.
+         """
+         super().__init__()
+
+         self._root = root
+         self._file_paths = file_paths
+         self._width = width
+         self._height = height
+         self._target_mpp = target_mpp
+         self._overwrite_mpp = overwrite_mpp
+         self._sampler = sampler
+         self._backend = backend
+         self._image_transforms = image_transforms
+
+         self._concat_dataset: torch_datasets.ConcatDataset
+
+     @property
+     def datasets(self) -> List[WsiDataset]:
+         """Returns the list of WSI datasets."""
+         return self._concat_dataset.datasets  # type: ignore
+
+     @property
+     def cumulative_sizes(self) -> List[int]:
+         """Returns the cumulative sizes of the WSI datasets."""
+         return self._concat_dataset.cumulative_sizes
+
+     @override
+     def configure(self) -> None:
+         self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+
+     @override
+     def __len__(self) -> int:
+         return len(self._concat_dataset)
+
+     @override
+     def __getitem__(self, index: int) -> tv_tensors.Image:
+         return self._concat_dataset[index]
+
+     @override
+     def filename(self, index: int) -> str:
+         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
+
+     def _load_datasets(self) -> list[WsiDataset]:
+         logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
+         wsi_datasets = []
+         for file_path in self._file_paths:
+             file_path = (
+                 os.path.join(self._root, file_path) if self._root not in file_path else file_path
+             )
+             if not os.path.exists(file_path):
+                 raise FileNotFoundError(f"File not found: {file_path}")
+
+             wsi_datasets.append(
+                 WsiDataset(
+                     file_path=file_path,
+                     width=self._width,
+                     height=self._height,
+                     sampler=self._sampler,
+                     target_mpp=self._target_mpp,
+                     overwrite_mpp=self._overwrite_mpp,
+                     backend=self._backend,
+                     image_transforms=self._image_transforms,
+                 )
+             )
+         return wsi_datasets
+
+     def _get_dataset_idx(self, index: int) -> int:
+         return bisect.bisect_right(self.cumulative_sizes, index)
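
For orientation, a minimal usage sketch of the new dataset class. The sampler class name and its arguments are assumptions for illustration only; see the sampler modules under eva/vision/data/wsi/patching/samplers/ in the file list above for the actual implementations.

from eva.vision.data.datasets import wsi as wsi_datasets
from eva.vision.data.wsi.patching import samplers

# Hypothetical sampler instantiation: the class name and its defaults
# are assumptions, not taken from this diff.
sampler = samplers.GridSampler()

dataset = wsi_datasets.WsiDataset(
    file_path="slide.svs",  # placeholder path to a local whole-slide image
    width=224,
    height=224,
    sampler=sampler,
    target_mpp=0.5,  # resample patches to 0.5 microns per pixel
)
patch = dataset[0]  # a tv_tensors.Image patch, read lazily via the cached WSI reader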

eva/vision/data/transforms/__init__.py
@@ -1,5 +1,6 @@
  """Vision data transforms."""
 
- from eva.vision.data.transforms.common import ResizeAndCrop
+ from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
+ from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity
 
- __all__ = ["ResizeAndCrop"]
+ __all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]

eva/vision/data/transforms/common/__init__.py
@@ -1,5 +1,6 @@
  """Common vision transforms."""
 
+ from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
  from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
 
- __all__ = ["ResizeAndCrop"]
+ __all__ = ["ResizeAndClamp", "ResizeAndCrop"]

eva/vision/data/transforms/common/resize_and_clamp.py
@@ -0,0 +1,51 @@
+ """Specialized transforms for resizing, clamping and range normalizing."""
+
+ from typing import Callable, Sequence, Tuple
+
+ from torchvision.transforms import v2
+
+ from eva.vision.data.transforms import normalization
+
+
+ class ResizeAndClamp(v2.Compose):
+     """Resizes, crops, clamps and normalizes an input image."""
+
+     def __init__(
+         self,
+         size: int | Sequence[int] = 224,
+         clamp_range: Tuple[int, int] = (-1024, 1024),
+         mean: Sequence[float] = (0.0, 0.0, 0.0),
+         std: Sequence[float] = (1.0, 1.0, 1.0),
+     ) -> None:
+         """Initializes the transform object.
+
+         Args:
+             size: Desired output size of the crop. If size is an `int` instead
+                 of sequence like (h, w), a square crop (size, size) is made.
+             clamp_range: The lower and upper bound to clamp the pixel values.
+             mean: Sequence of means for each image channel.
+             std: Sequence of standard deviations for each image channel.
+         """
+         self._size = size
+         self._clamp_range = clamp_range
+         self._mean = mean
+         self._std = std
+
+         super().__init__(transforms=self._build_transforms())
+
+     def _build_transforms(self) -> Sequence[Callable]:
+         """Builds and returns the list of transforms."""
+         transforms = [
+             v2.Resize(size=self._size),
+             v2.CenterCrop(size=self._size),
+             normalization.Clamp(out_range=self._clamp_range),
+             normalization.RescaleIntensity(
+                 in_range=self._clamp_range,
+                 out_range=(0.0, 1.0),
+             ),
+             v2.Normalize(
+                 mean=self._mean,
+                 std=self._std,
+             ),
+         ]
+         return transforms
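
A short usage sketch of this transform on a synthetic input; the default clamp range from the signature above is reminiscent of CT Hounsfield units:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import ResizeAndClamp

# Synthetic CT-like image with intensities spanning beyond the clamp range.
image = tv_tensors.Image(torch.randint(-2000, 2000, (3, 512, 512)).float())

transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))
output = transform(image)  # 224x224, clamped to [-1024, 1024], then rescaled to [0, 1]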

eva/vision/data/transforms/common/resize_and_crop.py
@@ -3,10 +3,10 @@
  from typing import Callable, Sequence
 
  import torch
- import torchvision.transforms.v2 as torch_transforms
+ from torchvision.transforms import v2
 
 
- class ResizeAndCrop(torch_transforms.Compose):
+ class ResizeAndCrop(v2.Compose):
      """Resizes, crops and normalizes an input image while preserving its aspect ratio."""
 
      def __init__(
@@ -32,11 +32,10 @@ class ResizeAndCrop(torch_transforms.Compose):
      def _build_transforms(self) -> Sequence[Callable]:
          """Builds and returns the list of transforms."""
          transforms = [
-             torch_transforms.ToImage(),
-             torch_transforms.Resize(size=self._size),
-             torch_transforms.CenterCrop(size=self._size),
-             torch_transforms.ToDtype(torch.float32, scale=True),
-             torch_transforms.Normalize(
+             v2.Resize(size=self._size),
+             v2.CenterCrop(size=self._size),
+             v2.ToDtype(torch.float32, scale=True),
+             v2.Normalize(
                  mean=self._mean,
                  std=self._std,
              ),

eva/vision/data/transforms/normalization/__init__.py
@@ -0,0 +1,6 @@
+ """Normalization related transformations."""
+
+ from eva.vision.data.transforms.normalization.clamp import Clamp
+ from eva.vision.data.transforms.normalization.rescale_intensity import RescaleIntensity
+
+ __all__ = ["Clamp", "RescaleIntensity"]

eva/vision/data/transforms/normalization/clamp.py
@@ -0,0 +1,43 @@
+ """Image clamp transform."""
+
+ import functools
+ from typing import Any, Dict, Tuple
+
+ import torch
+ import torchvision.transforms.v2 as torch_transforms
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+
+ class Clamp(torch_transforms.Transform):
+     """Clamps all elements in input into a specific range."""
+
+     def __init__(self, out_range: Tuple[int, int]) -> None:
+         """Initializes the transform.
+
+         Args:
+             out_range: The lower and upper bound of the range to
+                 be clamped to.
+         """
+         super().__init__()
+
+         self._out_range = out_range
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(torch.Tensor)
+     def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+         return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+
+     @_transform.register(tv_tensors.Image)
+     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+         inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
+         return tv_tensors.wrap(inpt_clamp, like=inpt)
+
+     @_transform.register(tv_tensors.BoundingBoxes)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+         return inpt
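
The `singledispatchmethod` registrations mean the transform clamps plain tensors and images but passes bounding boxes and masks through untouched, which keeps it safe inside segmentation pipelines. A brief sketch of that behavior:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import Clamp

clamp = Clamp(out_range=(-1024, 1024))

image = tv_tensors.Image(torch.tensor([[[-2000.0, 0.0, 3000.0]]]))
mask = tv_tensors.Mask(torch.tensor([[[0, 1, 2]]]))

clamp(image)  # values become [-1024.0, 0.0, 1024.0]
clamp(mask)   # returned unchanged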

eva/vision/data/transforms/normalization/functional/__init__.py
@@ -0,0 +1,5 @@
+ """Functional normalization related transformations API."""
+
+ from eva.vision.data.transforms.normalization.functional.rescale_intensity import rescale_intensity
+
+ __all__ = ["rescale_intensity"]

eva/vision/data/transforms/normalization/functional/rescale_intensity.py
@@ -0,0 +1,28 @@
+ """Intensity level functions."""
+
+ import sys
+ from typing import Tuple
+
+ import torch
+
+
+ def rescale_intensity(
+     image: torch.Tensor,
+     in_range: Tuple[float, float] | None = None,
+     out_range: Tuple[float, float] = (0.0, 1.0),
+ ) -> torch.Tensor:
+     """Stretches or shrinks the image intensity levels.
+
+     Args:
+         image: The image tensor as float-type.
+         in_range: The input data range. If `None`, it will
+             fetch the min and max of the input image.
+         out_range: The desired intensity range of the output.
+
+     Returns:
+         The image tensor after stretching or shrinking its intensity levels.
+     """
+     imin, imax = in_range or (image.min(), image.max())
+     omin, omax = out_range
+     image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
+     return image_scaled * (omax - omin) + omin
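
A quick check of the arithmetic with a synthetic tensor:

import torch

from eva.vision.data.transforms.normalization.functional import rescale_intensity

# (x - imin) / (imax - imin + eps) maps -1024 -> 0.0, 0 -> 0.5, 1024 -> 1.0,
# then the result is scaled into out_range.
image = torch.tensor([-1024.0, 0.0, 1024.0])
scaled = rescale_intensity(image, in_range=(-1024, 1024), out_range=(0.0, 1.0))
# scaled ≈ tensor([0.0, 0.5, 1.0]), up to the epsilon in the denominator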

eva/vision/data/transforms/normalization/rescale_intensity.py
@@ -0,0 +1,53 @@
+ """Intensity level scaling transform."""
+
+ import functools
+ from typing import Any, Dict, Tuple
+
+ import torch
+ import torchvision.transforms.v2 as torch_transforms
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+ from eva.vision.data.transforms.normalization import functional
+
+
+ class RescaleIntensity(torch_transforms.Transform):
+     """Stretches or shrinks the image intensity levels."""
+
+     def __init__(
+         self,
+         in_range: Tuple[float, float] | None = None,
+         out_range: Tuple[float, float] = (0.0, 1.0),
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             in_range: The input data range. If `None`, it will
+                 fetch the min and max of the input image.
+             out_range: The desired intensity range of the output.
+         """
+         super().__init__()
+
+         self._in_range = in_range
+         self._out_range = out_range
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(torch.Tensor)
+     def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
+         return functional.rescale_intensity(
+             inpt, in_range=self._in_range, out_range=self._out_range
+         )
+
+     @_transform.register(tv_tensors.Image)
+     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+         scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range)
+         return tv_tensors.wrap(scaled_inpt, like=inpt)
+
+     @_transform.register(tv_tensors.BoundingBoxes)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
+         return inpt

eva/vision/data/wsi/__init__.py
@@ -0,0 +1,16 @@
+ """WSI API."""
+
+ from eva.vision.data.wsi.backends import Wsi, get_cached_wsi, wsi_backend
+ from eva.vision.data.wsi.patching.coordinates import PatchCoordinates, get_cached_coords
+ from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level
+
+ __all__ = [
+     "Wsi",
+     "PatchCoordinates",
+     "Mask",
+     "get_cached_coords",
+     "wsi_backend",
+     "get_cached_wsi",
+     "get_mask",
+     "get_mask_level",
+ ]

eva/vision/data/wsi/backends/__init__.py
@@ -0,0 +1,69 @@
+ """WSI Backends API."""
+
+ import functools
+ import importlib.util
+ from typing import Callable
+
+ from eva.vision.data.wsi.backends.base import Wsi
+
+ LRU_CACHE_SIZE = 32
+
+
+ def _is_openslide_available() -> bool:
+     """Whether the OpenSlide library is available."""
+     return importlib.util.find_spec("openslide") is not None
+
+
+ def _is_tiffslide_available() -> bool:
+     """Whether the TiffSlide library is available."""
+     return importlib.util.find_spec("tiffslide") is not None
+
+
+ def is_backend_available(backend: str) -> bool:
+     """Whether the specified backend is available."""
+     match backend:
+         case "openslide":
+             return _is_openslide_available()
+         case "tiffslide":
+             return _is_tiffslide_available()
+     return False
+
+
+ def wsi_backend(backend: str = "openslide") -> Callable[..., Wsi]:
+     """Returns the backend to use for reading the whole-slide images."""
+     match backend:
+         case "openslide":
+             if _is_openslide_available():
+                 from eva.vision.data.wsi.backends.openslide import WsiOpenslide
+
+                 return WsiOpenslide
+             else:
+                 raise ValueError(
+                     "Missing optional dependency: openslide.\n"
+                     "Please install using `pip install openslide-python`."
+                 )
+         case "tiffslide":
+             if _is_tiffslide_available():
+                 from eva.vision.data.wsi.backends.tiffslide import WsiTiffslide
+
+                 return WsiTiffslide
+             else:
+                 raise ValueError(
+                     "Missing optional dependency: tiffslide.\n"
+                     "Please install using `pip install tiffslide`."
+                 )
+         case "pil":
+             from eva.vision.data.wsi.backends.pil import PILImage
+
+             return PILImage
+         case _:
+             raise ValueError(f"Unknown WSI backend selected: {backend}")
+
+
+ @functools.lru_cache(LRU_CACHE_SIZE)
+ def get_cached_wsi(file_path: str, backend: str, overwrite_mpp: float | None = None) -> Wsi:
+     """Returns a cached instance of the whole-slide image backend reader."""
+     return wsi_backend(backend)(file_path, overwrite_mpp)
+
+
+ __all__ = ["Wsi", "wsi_backend", "get_cached_wsi", "_is_openslide_available"]
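
A minimal sketch of resolving and using a backend reader; the file path is a placeholder, and OpenSlide must be installed for this branch to succeed:

from eva.vision.data import wsi

# Resolve the reader class for the requested backend and open a slide (cached).
reader = wsi.get_cached_wsi("slide.svs", backend="openslide")  # placeholder path
region = reader.read_region(location=(0, 0), level=0, size=(224, 224))
# region: numpy array of shape (224, 224, 3)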

eva/vision/data/wsi/backends/base.py
@@ -0,0 +1,115 @@
+ """Base Module for loading data from WSI files."""
+
+ import abc
+ from typing import Any, Sequence, Tuple
+
+ import numpy as np
+
+
+ class Wsi(abc.ABC):
+     """Base class for loading data from Whole Slide Image (WSI) files."""
+
+     def __init__(self, file_path: str, overwrite_mpp: float | None = None):
+         """Initializes a Wsi object.
+
+         Args:
+             file_path: The path to the WSI file.
+             overwrite_mpp: The microns per pixel (mpp) value to use when missing in WSI metadata.
+         """
+         self._wsi = self.open_file(file_path)
+         self._overwrite_mpp = overwrite_mpp
+
+     @abc.abstractmethod
+     def open_file(self, file_path: str) -> Any:
+         """Opens the WSI file.
+
+         Args:
+             file_path: The path to the WSI file.
+         """
+
+     @property
+     @abc.abstractmethod
+     def level_dimensions(self) -> Sequence[Tuple[int, int]]:
+         """A list of (width, height) tuples for each level, from highest to lowest resolution."""
+
+     @property
+     @abc.abstractmethod
+     def level_downsamples(self) -> Sequence[float]:
+         """A list of downsampling factors for each level, relative to the highest resolution."""
+
+     @property
+     @abc.abstractmethod
+     def mpp(self) -> float:
+         """Microns per pixel at the highest resolution (level 0)."""
+
+     @abc.abstractmethod
+     def _read_region(
+         self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+     ) -> np.ndarray:
+         """Abstract method to read a region at a specified zoom level."""
+
+     def read_region(
+         self, location: Tuple[int, int], level: int, size: Tuple[int, int]
+     ) -> np.ndarray:
+         """Reads and returns image data for a specified region and zoom level.
+
+         Args:
+             location: Top-left corner (x, y) to start reading at level 0.
+             level: WSI level to read from.
+             size: Region size as (width, height) in pixels at the selected read level.
+                 Remember to scale the size correctly.
+         """
+         self._verify_location(location, size)
+         data = self._read_region(location, level, size)
+         return self._read_postprocess(data)
+
+     def get_closest_level(self, target_mpp: float) -> int:
+         """Calculates the slide level that is closest to the target mpp.
+
+         Args:
+             target_mpp: The target microns per pixel (mpp) value.
+         """
+         # Calculate the mpp for each level
+         level_mpps = self.mpp * np.array(self.level_downsamples)
+
+         # Ignore levels with higher mpp
+         level_mpps_filtered = level_mpps.copy()
+         level_mpps_filtered[level_mpps_filtered > target_mpp] = 0
+
+         if level_mpps_filtered.max() == 0:
+             # When all levels have higher mpp than target_mpp return the level with lowest mpp
+             level_idx = np.argmin(level_mpps)
+         else:
+             level_idx = np.argmax(level_mpps_filtered)
+
+         return int(level_idx)
+
+     def _verify_location(self, location: Tuple[int, int], size: Tuple[int, int]) -> None:
+         """Verifies that the requested region is within the slide dimensions.
+
+         Args:
+             location: Top-left corner (x, y) to start reading at level 0.
+             size: Region size as (width, height) in pixels at the selected read level.
+         """
+         x_max, y_max = self.level_dimensions[0]
+         x_scale = x_max / self.level_dimensions[0][0]
+         y_scale = y_max / self.level_dimensions[0][1]
+
+         if (
+             int(location[0] + x_scale * size[0]) > x_max
+             or int(location[1] + y_scale * size[1]) > y_max
+         ):
+             raise ValueError(f"Out of bounds region: {location}, {size}")
+
+     def _read_postprocess(self, data: np.ndarray) -> np.ndarray:
+         """Post-processes the read region data.
+
+         Args:
+             data: The read region data as a numpy array of shape (height, width, channels).
+         """
+         # Change color to white where the alpha channel is 0
+         if data.shape[2] == 4:
+             data[data[:, :, 3] == 0] = 255
+
+         return data[:, :, :3]
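
To make the level-selection logic in `get_closest_level` concrete, here is the same computation on a hypothetical three-level pyramid (the numbers are invented for illustration):

import numpy as np

# Hypothetical slide: mpp = 0.25 at level 0, downsamples (1, 4, 16)
# -> per-level mpps are [0.25, 1.0, 4.0].
level_mpps = 0.25 * np.array([1.0, 4.0, 16.0])

target_mpp = 0.5
filtered = level_mpps.copy()
filtered[filtered > target_mpp] = 0  # ignore levels coarser than the target
# filtered == [0.25, 0.0, 0.0] -> argmax picks level 0, the finest level
# whose mpp does not exceed the target.
level_idx = int(np.argmax(filtered)) if filtered.max() > 0 else int(np.argmin(level_mpps))
assert level_idx == 0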