kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva has been flagged by the registry; consult the registry's advisory page for details.

Files changed (168):
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,8 @@
3
3
  import os
4
4
  from typing import Callable, Dict, List, Literal, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from torchvision.datasets import folder, utils
8
9
  from typing_extensions import override
9
10
 
@@ -52,8 +53,7 @@ class BACH(base.ImageClassification):
52
53
  root: str,
53
54
  split: Literal["train", "val"] | None = None,
54
55
  download: bool = False,
55
- image_transforms: Callable | None = None,
56
- target_transforms: Callable | None = None,
56
+ transforms: Callable | None = None,
57
57
  ) -> None:
58
58
  """Initialize the dataset.
59
59
 
@@ -68,15 +68,10 @@ class BACH(base.ImageClassification):
68
68
  Note that the download will be executed only by additionally
69
69
  calling the :meth:`prepare_data` method and if the data does
70
70
  not yet exist on disk.
71
- image_transforms: A function/transform that takes in an image
72
- and returns a transformed version.
73
- target_transforms: A function/transform that takes in the target
74
- and transforms it.
71
+ transforms: A function/transform which returns a transformed
72
+ version of the raw data samples.
75
73
  """
76
- super().__init__(
77
- image_transforms=image_transforms,
78
- target_transforms=target_transforms,
79
- )
74
+ super().__init__(transforms=transforms)
80
75
 
81
76
  self._root = root
82
77
  self._split = split
@@ -130,14 +125,14 @@ class BACH(base.ImageClassification):
130
125
  )
131
126
 
132
127
  @override
133
- def load_image(self, index: int) -> np.ndarray:
128
+ def load_image(self, index: int) -> tv_tensors.Image:
134
129
  image_path, _ = self._samples[self._indices[index]]
135
- return io.read_image(image_path)
130
+ return io.read_image_as_tensor(image_path)
136
131
 
137
132
  @override
138
- def load_target(self, index: int) -> np.ndarray:
133
+ def load_target(self, index: int) -> torch.Tensor:
139
134
  _, target = self._samples[self._indices[index]]
140
- return np.asarray(target, dtype=np.int64)
135
+ return torch.tensor(target, dtype=torch.long)
141
136
 
142
137
  @override
143
138
  def __len__(self) -> int:
@@ -3,32 +3,29 @@
3
3
  import abc
4
4
  from typing import Any, Callable, Dict, List, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from typing_extensions import override
8
9
 
9
10
  from eva.vision.data.datasets import vision
10
11
 
11
12
 
12
- class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
13
+ class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
13
14
  """Image classification abstract dataset."""
14
15
 
15
16
  def __init__(
16
17
  self,
17
- image_transforms: Callable | None = None,
18
- target_transforms: Callable | None = None,
18
+ transforms: Callable | None = None,
19
19
  ) -> None:
20
20
  """Initializes the image classification dataset.
21
21
 
22
22
  Args:
23
- image_transforms: A function/transform that takes in an image
24
- and returns a transformed version.
25
- target_transforms: A function/transform that takes in the target
26
- and transforms it.
23
+ transforms: A function/transform which returns a transformed
24
+ version of the raw data samples.
27
25
  """
28
26
  super().__init__()
29
27
 
30
- self._image_transforms = image_transforms
31
- self._target_transforms = target_transforms
28
+ self._transforms = transforms
32
29
 
33
30
  @property
34
31
  def classes(self) -> List[str] | None:
@@ -38,19 +35,18 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
38
35
  def class_to_idx(self) -> Dict[str, int] | None:
39
36
  """Returns a mapping of the class name to its target index."""
40
37
 
41
- def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
38
+ def load_metadata(self, index: int) -> Dict[str, Any] | None:
42
39
  """Returns the dataset metadata.
43
40
 
44
41
  Args:
45
42
  index: The index of the data sample to return the metadata of.
46
- If `None`, it will return the metadata of the current dataset.
47
43
 
48
44
  Returns:
49
45
  The sample metadata.
50
46
  """
51
47
 
52
48
  @abc.abstractmethod
53
- def load_image(self, index: int) -> np.ndarray:
49
+ def load_image(self, index: int) -> tv_tensors.Image:
54
50
  """Returns the `index`'th image sample.
55
51
 
56
52
  Args:
@@ -61,7 +57,7 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
61
57
  """
62
58
 
63
59
  @abc.abstractmethod
64
- def load_target(self, index: int) -> np.ndarray:
60
+ def load_target(self, index: int) -> torch.Tensor:
65
61
  """Returns the `index`'th target sample.
66
62
 
67
63
  Args:
@@ -77,14 +73,15 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
77
73
  raise NotImplementedError
78
74
 
79
75
  @override
80
- def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
76
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
81
77
  image = self.load_image(index)
82
78
  target = self.load_target(index)
83
- return self._apply_transforms(image, target)
79
+ image, target = self._apply_transforms(image, target)
80
+ return image, target, self.load_metadata(index) or {}
84
81
 
85
82
  def _apply_transforms(
86
- self, image: np.ndarray, target: np.ndarray
87
- ) -> Tuple[np.ndarray, np.ndarray]:
83
+ self, image: tv_tensors.Image, target: torch.Tensor
84
+ ) -> Tuple[tv_tensors.Image, torch.Tensor]:
88
85
  """Applies the transforms to the provided data and returns them.
89
86
 
90
87
  Args:
@@ -94,10 +91,6 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
94
91
  Returns:
95
92
  A tuple with the image and the target transformed.
96
93
  """
97
- if self._image_transforms is not None:
98
- image = self._image_transforms(image)
99
-
100
- if self._target_transforms is not None:
101
- target = self._target_transforms(target)
102
-
94
+ if self._transforms is not None:
95
+ image, target = self._transforms(image, target)
103
96
  return image, target
@@ -0,0 +1,244 @@
1
+ """Camelyon16 dataset class."""
2
+
3
+ import functools
4
+ import glob
5
+ import os
6
+ from typing import Any, Callable, Dict, List, Literal, Tuple
7
+
8
+ import pandas as pd
9
+ import torch
10
+ from torchvision import tv_tensors
11
+ from torchvision.transforms.v2 import functional
12
+ from typing_extensions import override
13
+
14
+ from eva.vision.data.datasets import _validators, wsi
15
+ from eva.vision.data.datasets.classification import base
16
+ from eva.vision.data.wsi.patching import samplers
17
+
18
+
19
+ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
20
+ """Dataset class for Camelyon16 images and corresponding targets."""
21
+
22
+ _val_slides = [
23
+ "normal_010",
24
+ "normal_013",
25
+ "normal_016",
26
+ "normal_017",
27
+ "normal_019",
28
+ "normal_020",
29
+ "normal_025",
30
+ "normal_030",
31
+ "normal_031",
32
+ "normal_032",
33
+ "normal_052",
34
+ "normal_056",
35
+ "normal_057",
36
+ "normal_067",
37
+ "normal_076",
38
+ "normal_079",
39
+ "normal_085",
40
+ "normal_095",
41
+ "normal_098",
42
+ "normal_099",
43
+ "normal_101",
44
+ "normal_102",
45
+ "normal_105",
46
+ "normal_106",
47
+ "normal_109",
48
+ "normal_129",
49
+ "normal_132",
50
+ "normal_137",
51
+ "normal_142",
52
+ "normal_143",
53
+ "normal_148",
54
+ "normal_152",
55
+ "tumor_001",
56
+ "tumor_005",
57
+ "tumor_011",
58
+ "tumor_012",
59
+ "tumor_013",
60
+ "tumor_019",
61
+ "tumor_031",
62
+ "tumor_037",
63
+ "tumor_043",
64
+ "tumor_046",
65
+ "tumor_057",
66
+ "tumor_065",
67
+ "tumor_069",
68
+ "tumor_071",
69
+ "tumor_073",
70
+ "tumor_079",
71
+ "tumor_080",
72
+ "tumor_081",
73
+ "tumor_082",
74
+ "tumor_085",
75
+ "tumor_097",
76
+ "tumor_109",
77
+ ]
78
+ """Validation slide names, same as the ones in patch camelyon."""
79
+
80
+ def __init__(
81
+ self,
82
+ root: str,
83
+ sampler: samplers.Sampler,
84
+ split: Literal["train", "val", "test"] | None = None,
85
+ width: int = 224,
86
+ height: int = 224,
87
+ target_mpp: float = 0.5,
88
+ backend: str = "openslide",
89
+ image_transforms: Callable | None = None,
90
+ seed: int = 42,
91
+ ) -> None:
92
+ """Initializes the dataset.
93
+
94
+ Args:
95
+ root: Root directory of the dataset.
96
+ sampler: The sampler to use for sampling patch coordinates.
97
+ split: Dataset split to use. If `None`, the entire dataset is used.
98
+ width: Width of the patches to be extracted, in pixels.
99
+ height: Height of the patches to be extracted, in pixels.
100
+ target_mpp: Target microns per pixel (mpp) for the patches.
101
+ backend: The backend to use for reading the whole-slide images.
102
+ image_transforms: Transforms to apply to the extracted image patches.
103
+ seed: Random seed for reproducibility.
104
+ """
105
+ self._split = split
106
+ self._root = root
107
+ self._width = width
108
+ self._height = height
109
+ self._target_mpp = target_mpp
110
+ self._seed = seed
111
+
112
+ wsi.MultiWsiDataset.__init__(
113
+ self,
114
+ root=root,
115
+ file_paths=self._load_file_paths(split),
116
+ width=width,
117
+ height=height,
118
+ sampler=sampler,
119
+ target_mpp=target_mpp,
120
+ backend=backend,
121
+ image_transforms=image_transforms,
122
+ )
123
+
124
+ @property
125
+ @override
126
+ def classes(self) -> List[str]:
127
+ return ["normal", "tumor"]
128
+
129
+ @property
130
+ @override
131
+ def class_to_idx(self) -> Dict[str, int]:
132
+ return {"normal": 0, "tumor": 1}
133
+
134
+ @functools.cached_property
135
+ def annotations_test_set(self) -> Dict[str, str]:
136
+ """Loads the dataset labels."""
137
+ path = os.path.join(self._root, "testing/reference.csv")
138
+ reference_df = pd.read_csv(path, header=None)
139
+ return {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
140
+
141
+ @functools.cached_property
142
+ def annotations(self) -> Dict[str, str]:
143
+ """Loads the dataset labels."""
144
+ annotations = {}
145
+ if self._split in ["test", None]:
146
+ path = os.path.join(self._root, "testing/reference.csv")
147
+ reference_df = pd.read_csv(path, header=None)
148
+ annotations.update(
149
+ {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
150
+ )
151
+
152
+ if self._split in ["train", "val", None]:
153
+ annotations.update(
154
+ {
155
+ self._get_id_from_path(file_path): self._get_class_from_path(file_path)
156
+ for file_path in self._file_paths
157
+ if "test" not in file_path
158
+ }
159
+ )
160
+ return annotations
161
+
162
+ @override
163
+ def prepare_data(self) -> None:
164
+ _validators.check_dataset_exists(self._root, False)
165
+
166
+ expected_directories = ["training/normal", "training/tumor", "testing/images"]
167
+ for resource in expected_directories:
168
+ if not os.path.isdir(os.path.join(self._root, resource)):
169
+ raise FileNotFoundError(f"'{resource}' not found in the root folder.")
170
+
171
+ if not os.path.isfile(os.path.join(self._root, "testing/reference.csv")):
172
+ raise FileNotFoundError("'reference.csv' file not found in the testing folder.")
173
+
174
+ @override
175
+ def validate(self) -> None:
176
+
177
+ expected_n_files = {
178
+ "train": 216,
179
+ "val": 54,
180
+ "test": 129,
181
+ None: 399,
182
+ }
183
+ _validators.check_number_of_files(
184
+ self._file_paths, expected_n_files[self._split], self._split
185
+ )
186
+ _validators.check_dataset_integrity(
187
+ self,
188
+ length=None,
189
+ n_classes=2,
190
+ first_and_last_labels=("normal", "tumor"),
191
+ )
192
+
193
+ @override
194
+ def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
195
+ return base.ImageClassification.__getitem__(self, index)
196
+
197
+ @override
198
+ def load_image(self, index: int) -> tv_tensors.Image:
199
+ image_array = wsi.MultiWsiDataset.__getitem__(self, index)
200
+ return functional.to_image(image_array)
201
+
202
+ @override
203
+ def load_target(self, index: int) -> torch.Tensor:
204
+ file_path = self._file_paths[self._get_dataset_idx(index)]
205
+ class_name = self.annotations[self._get_id_from_path(file_path)]
206
+ return torch.tensor(self.class_to_idx[class_name], dtype=torch.int64)
207
+
208
+ @override
209
+ def load_metadata(self, index: int) -> Dict[str, Any]:
210
+ return {"wsi_id": self.filename(index).split(".")[0]}
211
+
212
+ def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
213
+ """Loads the file paths of the corresponding dataset split."""
214
+ train_paths, val_paths = [], []
215
+ for path in glob.glob(os.path.join(self._root, "training/**/*.tif")):
216
+ if self._get_id_from_path(path) in self._val_slides:
217
+ val_paths.append(path)
218
+ else:
219
+ train_paths.append(path)
220
+ test_paths = glob.glob(os.path.join(self._root, "testing/images", "*.tif"))
221
+
222
+ match split:
223
+ case "train":
224
+ paths = train_paths
225
+ case "val":
226
+ paths = val_paths
227
+ case "test":
228
+ paths = test_paths
229
+ case None:
230
+ paths = train_paths + val_paths + test_paths
231
+ case _:
232
+ raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
233
+ return sorted([os.path.relpath(path, self._root) for path in paths])
234
+
235
+ def _get_id_from_path(self, file_path: str) -> str:
236
+ """Extracts the slide ID from the file path."""
237
+ return os.path.basename(file_path).replace(".tif", "")
238
+
239
+ def _get_class_from_path(self, file_path: str) -> str:
240
+ """Extracts the class name from the file path."""
241
+ class_name = self._get_id_from_path(file_path).split("_")[0]
242
+ if class_name not in self.classes:
243
+ raise ValueError(f"Invalid class name '{class_name}' in file path '{file_path}'.")
244
+ return class_name
@@ -3,7 +3,8 @@
3
3
  import os
4
4
  from typing import Callable, Dict, List, Literal, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from torchvision.datasets import folder, utils
8
9
  from typing_extensions import override
9
10
 
@@ -37,8 +38,7 @@ class CRC(base.ImageClassification):
37
38
  root: str,
38
39
  split: Literal["train", "val"],
39
40
  download: bool = False,
40
- image_transforms: Callable | None = None,
41
- target_transforms: Callable | None = None,
41
+ transforms: Callable | None = None,
42
42
  ) -> None:
43
43
  """Initializes the dataset.
44
44
 
@@ -56,15 +56,10 @@ class CRC(base.ImageClassification):
56
56
  Note that the download will be executed only by additionally
57
57
  calling the :meth:`prepare_data` method and if the data does
58
58
  not yet exist on disk.
59
- image_transforms: A function/transform that takes in an image
60
- and returns a transformed version.
61
- target_transforms: A function/transform that takes in the target
62
- and transforms it.
59
+ transforms: A function/transform which returns a transformed
60
+ version of the raw data samples.
63
61
  """
64
- super().__init__(
65
- image_transforms=image_transforms,
66
- target_transforms=target_transforms,
67
- )
62
+ super().__init__(transforms=transforms)
68
63
 
69
64
  self._root = root
70
65
  self._split = split
@@ -122,14 +117,14 @@ class CRC(base.ImageClassification):
122
117
  )
123
118
 
124
119
  @override
125
- def load_image(self, index: int) -> np.ndarray:
120
+ def load_image(self, index: int) -> tv_tensors.Image:
126
121
  image_path, _ = self._samples[index]
127
- return io.read_image(image_path)
122
+ return io.read_image_as_tensor(image_path)
128
123
 
129
124
  @override
130
- def load_target(self, index: int) -> np.ndarray:
125
+ def load_target(self, index: int) -> torch.Tensor:
131
126
  _, target = self._samples[index]
132
- return np.asarray(target, dtype=np.int64)
127
+ return torch.tensor(target, dtype=torch.long)
133
128
 
134
129
  @override
135
130
  def __len__(self) -> int:
@@ -3,7 +3,8 @@
3
3
  import os
4
4
  from typing import Callable, Dict, List, Literal, Tuple
5
5
 
6
- import numpy as np
6
+ import torch
7
+ from torchvision import tv_tensors
7
8
  from typing_extensions import override
8
9
 
9
10
  from eva.vision.data.datasets import _validators
@@ -18,23 +19,17 @@ class MHIST(base.ImageClassification):
18
19
  self,
19
20
  root: str,
20
21
  split: Literal["train", "test"],
21
- image_transforms: Callable | None = None,
22
- target_transforms: Callable | None = None,
22
+ transforms: Callable | None = None,
23
23
  ) -> None:
24
24
  """Initialize the dataset.
25
25
 
26
26
  Args:
27
27
  root: Path to the root directory of the dataset.
28
28
  split: Dataset split to use.
29
- image_transforms: A function/transform that takes in an image
30
- and returns a transformed version.
31
- target_transforms: A function/transform that takes in the target
32
- and transforms it.
29
+ transforms: A function/transform which returns a transformed
30
+ version of the raw data samples.
33
31
  """
34
- super().__init__(
35
- image_transforms=image_transforms,
36
- target_transforms=target_transforms,
37
- )
32
+ super().__init__(transforms=transforms)
38
33
 
39
34
  self._root = root
40
35
  self._split = split
@@ -74,16 +69,16 @@ class MHIST(base.ImageClassification):
74
69
  )
75
70
 
76
71
  @override
77
- def load_image(self, index: int) -> np.ndarray:
72
+ def load_image(self, index: int) -> tv_tensors.Image:
78
73
  image_filename, _ = self._samples[index]
79
74
  image_path = os.path.join(self._dataset_path, image_filename)
80
- return io.read_image(image_path)
75
+ return io.read_image_as_tensor(image_path)
81
76
 
82
77
  @override
83
- def load_target(self, index: int) -> np.ndarray:
78
+ def load_target(self, index: int) -> torch.Tensor:
84
79
  _, label = self._samples[index]
85
80
  target = self.class_to_idx[label]
86
- return np.asarray(target, dtype=np.int64)
81
+ return torch.tensor(target, dtype=torch.float32)
87
82
 
88
83
  @override
89
84
  def __len__(self) -> int: