kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (85)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/trainers/_recorder.py +4 -1
  8. eva/core/utils/suppress_logs.py +28 -0
  9. eva/vision/data/__init__.py +2 -2
  10. eva/vision/data/dataloaders/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  12. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  13. eva/vision/data/datasets/__init__.py +2 -2
  14. eva/vision/data/datasets/classification/bach.py +3 -4
  15. eva/vision/data/datasets/classification/bracs.py +3 -4
  16. eva/vision/data/datasets/classification/breakhis.py +3 -4
  17. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  18. eva/vision/data/datasets/classification/crc.py +3 -4
  19. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  20. eva/vision/data/datasets/classification/mhist.py +3 -4
  21. eva/vision/data/datasets/classification/panda.py +4 -5
  22. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  23. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  24. eva/vision/data/datasets/classification/wsi.py +6 -5
  25. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  26. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  27. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  28. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  29. eva/vision/data/datasets/segmentation/consep.py +6 -7
  30. eva/vision/data/datasets/segmentation/lits.py +9 -8
  31. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  32. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  33. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  34. eva/vision/data/datasets/vision.py +95 -4
  35. eva/vision/data/datasets/wsi.py +5 -5
  36. eva/vision/data/transforms/__init__.py +22 -3
  37. eva/vision/data/transforms/common/__init__.py +1 -2
  38. eva/vision/data/transforms/croppad/__init__.py +11 -0
  39. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  40. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  41. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  42. eva/vision/data/transforms/intensity/__init__.py +11 -0
  43. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  44. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  45. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  46. eva/vision/data/transforms/spatial/__init__.py +7 -0
  47. eva/vision/data/transforms/spatial/flip.py +72 -0
  48. eva/vision/data/transforms/spatial/rotate.py +53 -0
  49. eva/vision/data/transforms/spatial/spacing.py +69 -0
  50. eva/vision/data/transforms/utility/__init__.py +5 -0
  51. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  52. eva/vision/data/tv_tensors/__init__.py +5 -0
  53. eva/vision/data/tv_tensors/volume.py +61 -0
  54. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  55. eva/vision/models/modules/semantic_segmentation.py +32 -19
  56. eva/vision/models/networks/backbones/__init__.py +9 -2
  57. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  58. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  59. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  60. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  61. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  62. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  63. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  64. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  65. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  66. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  67. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  68. eva/vision/utils/io/__init__.py +2 -0
  69. eva/vision/utils/io/nifti.py +91 -11
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
  72. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
  73. eva/vision/data/datasets/classification/base.py +0 -96
  74. eva/vision/data/datasets/segmentation/base.py +0 -96
  75. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  76. eva/vision/data/transforms/normalization/__init__.py +0 -6
  77. eva/vision/data/transforms/normalization/clamp.py +0 -43
  78. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  79. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  80. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  81. eva/vision/metrics/segmentation/BUILD +0 -1
  82. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  83. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
  85. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0

eva/core/data/datasets/base.py
@@ -1,6 +1,7 @@
 """Base dataset class."""

 import abc
+from typing import Generic, TypeVar

 from eva.core.data.datasets import dataset

@@ -55,11 +56,15 @@ class Dataset(dataset.TorchDataset):
     """


-class MapDataset(Dataset):
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class MapDataset(Dataset, abc.ABC, Generic[DataSample]):
     """Abstract base class for all map-style datasets."""

     @abc.abstractmethod
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> DataSample:
         """Retrieves the item at the given index.

         Args:

eva/core/models/modules/head.py
@@ -1,6 +1,6 @@
 """Neural Network Head Module."""

-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List

 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -108,7 +108,9 @@ class HeadModule(module.ModelModule):
         return self._batch_step(batch)

     @override
-    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def predict_step(
+        self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | List[torch.Tensor]:
         tensor = INPUT_BATCH(*batch).data
         return tensor if self.backbone is None else self.backbone(tensor)

eva/core/models/modules/typings.py
@@ -1,6 +1,6 @@
 """Type annotations for model modules."""

-from typing import Any, Dict, NamedTuple
+from typing import Any, Dict, List, NamedTuple

 import lightning.pytorch as pl
 import torch
@@ -13,7 +13,7 @@ MODEL_TYPE = nn.Module | pl.LightningModule
 class INPUT_BATCH(NamedTuple):
     """The default input batch data scheme."""

-    data: torch.Tensor
+    data: torch.Tensor | List[torch.Tensor]
     """The data batch."""

     targets: torch.Tensor | None = None

eva/core/models/transforms/__init__.py
@@ -1,6 +1,7 @@
 """Model outputs transforms API."""

+from eva.core.models.transforms.as_discrete import AsDiscrete
 from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
 from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures

-__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
+__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]

eva/core/models/transforms/as_discrete.py
@@ -0,0 +1,57 @@
+"""Defines the AsDiscrete transformation."""
+
+import torch
+
+
+class AsDiscrete:
+    """Convert the logits tensor to discrete values."""
+
+    def __init__(
+        self,
+        argmax: bool = False,
+        to_onehot: int | bool | None = None,
+        threshold: float | None = None,
+    ) -> None:
+        """Convert the input tensor/array into discrete values.
+
+        Args:
+            argmax: Whether to execute argmax function on input data before transform.
+            to_onehot: if not None, convert input data into the one-hot format with
+                specified number of classes. If bool, it will try to infer the number
+                of classes.
+            threshold: If not None, threshold the float values to int number 0 or 1
+                with specified threshold.
+        """
+        super().__init__()
+
+        self._argmax = argmax
+        self._to_onehot = to_onehot
+        self._threshold = threshold
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation."""
+        if self._argmax:
+            tensor = torch.argmax(tensor, dim=1, keepdim=True)
+
+        if self._to_onehot is not None:
+            tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)
+
+        if self._threshold is not None:
+            tensor = tensor >= self._threshold
+
+        return tensor
+
+
+def _one_hot(
+    tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1
+) -> torch.Tensor:
+    """Convert input tensor into one-hot format (implementation taken from MONAI)."""
+    shape = list(tensor.shape)
+    if shape[dim] != 1:
+        raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.")
+
+    shape[dim] = num_classes
+    o = torch.zeros(size=shape, dtype=dtype, device=tensor.device)
+    tensor = o.scatter_(dim=dim, index=tensor.long(), value=1)
+
+    return tensor
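
A minimal usage sketch for the new transform (not part of the diff; shapes are illustrative):

    import torch

    from eva.core.models.transforms import AsDiscrete

    # Batch of logits with shape (batch, classes, height, width).
    logits = torch.randn(2, 4, 8, 8)

    # Argmax over the class dim, then expand back to one-hot with 4 classes.
    transform = AsDiscrete(argmax=True, to_onehot=4)
    discrete = transform(logits)
    print(discrete.shape)  # torch.Size([2, 4, 8, 8]), dtype torch.long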

eva/core/models/wrappers/_utils.py
@@ -1,8 +1,17 @@
 """Utilities and helper functions for models."""

+import hashlib
+import os
+import sys
+from typing import Any, Dict
+
+import torch
+from fsspec.core import url_to_fs
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
-from torch import nn
+from torch import hub, nn
+
+from eva.core.utils.progress_bar import tqdm


 def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
@@ -23,3 +32,114 @@ def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
     model.load_state_dict(checkpoint, strict=True)

     logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
+
+
+def load_state_dict_from_url(
+    url: str,
+    *,
+    model_dir: str | None = None,
+    filename: str | None = None,
+    progress: bool = True,
+    md5: str | None = None,
+    force: bool = False,
+) -> Dict[str, Any]:
+    """Loads the Torch serialized object at the given URL.
+
+    If the object is already present and valid in `model_dir`, it's
+    deserialized and returned.
+
+    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
+    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+    Args:
+        url: URL of the object to download.
+        model_dir: Directory in which to save the object.
+        filename: Name for the downloaded file. Filename from ``url`` will be used if not set.
+        progress: Whether or not to display a progress bar to stderr.
+        md5: MD5 file code to check whether the file is valid. If not, it will re-download it.
+        force: Whether to download the file regardless if it exists.
+    """
+    model_dir = model_dir or os.path.join(hub.get_dir(), "checkpoints")
+    os.makedirs(model_dir, exist_ok=True)
+
+    cached_file = os.path.join(model_dir, filename or os.path.basename(url))
+    if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
+        sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
+        _download_url_to_file(url, cached_file, progress=progress)
+        if md5 is None or not _check_integrity(cached_file, md5):
+            sys.stderr.write(f"File MD5: {_calculate_md5(cached_file)}\n")
+
+    return torch.load(cached_file, map_location="cpu")
+
+
+def _download_url_to_file(
+    url: str,
+    dst: str,
+    *,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    try:
+        _download_with_fsspec(url=url, dst=dst, progress=progress)
+    except Exception:
+        try:
+            hub.download_url_to_file(url=url, dst=dst, progress=progress)
+        except Exception as hub_e:
+            raise RuntimeError(
+                f"Failed to download file from {url} using both fsspec and hub."
+            ) from hub_e
+
+
+def _download_with_fsspec(
+    url: str,
+    dst: str,
+    *,
+    chunk_size: int = 1024 * 1024,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path using fsspec.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    filesystem, _ = url_to_fs(url, anon=False)
+    total_size_bytes = filesystem.size(url)
+    with (
+        filesystem.open(url, "rb") as remote_file,
+        tqdm(
+            total=total_size_bytes,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+            disable=not progress,
+        ) as pbar,
+    ):
+        with open(dst, "wb") as local_file:
+            while True:
+                data = remote_file.read(chunk_size)
+                if not data:
+                    break
+
+                local_file.write(data)
+                pbar.update(chunk_size)
+
+
+def _calculate_md5(path: str) -> str:
+    """Calculate the md5 hash of a file."""
+    with open(path, "rb") as file:
+        return hashlib.md5(file.read(), usedforsecurity=False).hexdigest()
+
+
+def _check_integrity(path: str, md5: str | None) -> bool:
+    """Check if the file matches the specified md5 hash."""
+    return (md5 is None) or (md5 == _calculate_md5(path))
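
For context, a sketch of how the new download helper might be called (the URL is a placeholder, not a checkpoint shipped with this release):

    from torch import nn

    from eva.core.models.wrappers._utils import load_state_dict_from_url

    # Integrity checking is skipped when md5 is None; the file is cached
    # under <hub_dir>/checkpoints and reused on subsequent calls.
    state_dict = load_state_dict_from_url("https://example.com/checkpoints/backbone.pth")
    model = nn.Linear(10, 10)  # stand-in for a real backbone
    model.load_state_dict(state_dict, strict=False)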

eva/core/trainers/_recorder.py
@@ -129,7 +129,10 @@ class SessionRecorder:
     def _save_config(self) -> None:
         """Saves the config yaml with resolved env placeholders to the output directory."""
         if self.config_path:
-            config = OmegaConf.load(self.config_path)
+            config_fs = cloud_io.get_filesystem(self.config_path)
+            with config_fs.open(self.config_path, "r") as config_file:
+                config = OmegaConf.load(config_file)  # type: ignore
+
             fs = cloud_io.get_filesystem(self._output_dir, anon=False)
             with fs.open(os.path.join(self._output_dir, self._config_file), "w") as file:
                 config_yaml = OmegaConf.to_yaml(config, resolve=True)

eva/core/utils/suppress_logs.py
@@ -0,0 +1,28 @@
+"""Context manager to temporarily suppress all logging outputs."""
+
+import logging
+import sys
+from types import TracebackType
+from typing import Type
+
+
+class SuppressLogs:
+    """Context manager to suppress all logs but print exceptions if they occur."""
+
+    def __enter__(self) -> None:
+        """Temporarily increase log level to suppress all logs."""
+        self._logger = logging.getLogger()
+        self._previous_level = self._logger.level
+        self._logger.setLevel(logging.CRITICAL + 1)
+
+    def __exit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool:
+        """Restores the previous logging level and print exceptions."""
+        self._logger.setLevel(self._previous_level)
+        if exc_value:
+            print(f"Error: {exc_value}", file=sys.stderr)
+        return False
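
A minimal usage sketch for the new context manager (logger names are arbitrary):

    import logging

    from eva.core.utils.suppress_logs import SuppressLogs

    logging.basicConfig(level=logging.INFO)

    with SuppressLogs():
        logging.getLogger("noisy").info("suppressed")

    logging.getLogger("noisy").info("root level restored; this prints again")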

eva/vision/data/__init__.py
@@ -1,5 +1,5 @@
 """Vision data API."""

-from eva.vision.data import datasets, transforms
+from eva.vision.data import datasets, transforms, tv_tensors

-__all__ = ["datasets", "transforms"]
+__all__ = ["datasets", "transforms", "tv_tensors"]

eva/vision/data/dataloaders/__init__.py
@@ -0,0 +1,5 @@
+"""Dataloader related utilities and functions."""
+
+from eva.vision.data.dataloaders import collate_fn
+
+__all__ = ["collate_fn"]

eva/vision/data/dataloaders/collate_fn/__init__.py
@@ -0,0 +1,5 @@
+"""Dataloader collate API."""
+
+from eva.vision.data.dataloaders.collate_fn.collection import collection_collate
+
+__all__ = ["collection_collate"]

eva/vision/data/dataloaders/collate_fn/collection.py
@@ -0,0 +1,22 @@
+"""Data only collate filter function."""
+
+from typing import Any, List
+
+import torch
+
+from eva.core.models.modules.typings import INPUT_BATCH
+
+
+def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any:
+    """Collate function for stacking a collection of data samples.
+
+    Args:
+        batch: The batch to be collated.
+
+    Returns:
+        The collated batch.
+    """
+    tensors, targets, metadata = zip(*batch, strict=False)
+    batch_tensors = torch.cat(list(map(torch.stack, tensors)))
+    batch_targets = torch.cat(list(map(torch.stack, targets)))
+    return batch_tensors, batch_targets, metadata
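
To illustrate the shape contract, a small sketch (tensor sizes are made up): each item yields a list of per-crop tensors and targets, which are stacked per item and then concatenated along the batch dimension:

    import torch

    from eva.vision.data.dataloaders.collate_fn import collection_collate

    def make_item(n_crops: int = 3):
        crops = [torch.rand(1, 4, 4) for _ in range(n_crops)]
        targets = [torch.tensor(0) for _ in range(n_crops)]
        return crops, targets, {"id": "item"}

    tensors, targets, metadata = collection_collate([make_item(), make_item()])
    print(tensors.shape)  # torch.Size([6, 1, 4, 4]) -> 2 items x 3 crops
    print(targets.shape)  # torch.Size([6])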

eva/vision/data/datasets/__init__.py
@@ -16,9 +16,9 @@ from eva.vision.data.datasets.classification import (
 )
 from eva.vision.data.datasets.segmentation import (
     BCSS,
+    BTCV,
     CoNSeP,
     EmbeddingsSegmentationDataset,
-    ImageSegmentation,
     LiTS,
     LiTSBalanced,
     MoNuSAC,
@@ -29,6 +29,7 @@ from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset

 __all__ = [
     "BACH",
+    "BTCV",
     "BCSS",
     "BreaKHis",
     "BRACS",
@@ -43,7 +44,6 @@ __all__ = [
     "WsiClassificationDataset",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "ImageSegmentation",
     "LiTS",
     "LiTSBalanced",
     "MoNuSAC",

eva/vision/data/datasets/classification/bach.py
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override

-from eva.vision.data.datasets import _utils, _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _utils, _validators, structs, vision
 from eva.vision.utils import io


-class BACH(base.ImageClassification):
+class BACH(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BACH images and corresponding targets."""

     _train_index_ranges: List[Tuple[int, int]] = [
@@ -125,7 +124,7 @@ class BACH(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[self._indices[index]]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/bracs.py
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder
 from typing_extensions import override

-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io


-class BRACS(base.ImageClassification):
+class BRACS(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BRACS images and corresponding targets."""

     _expected_dataset_lengths: Dict[str, int] = {
@@ -80,7 +79,7 @@ class BRACS(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/breakhis.py
@@ -10,12 +10,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override

-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io


-class BreaKHis(base.ImageClassification):
+class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BreaKHis images and corresponding targets."""

     _resources: List[structs.DownloadResource] = [
@@ -145,7 +144,7 @@ class BreaKHis(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/camelyon16.py
@@ -11,12 +11,11 @@ from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional
 from typing_extensions import override

-from eva.vision.data.datasets import _validators, wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision, wsi
 from eva.vision.data.wsi.patching import samplers


-class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
+class Camelyon16(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for Camelyon16 images and corresponding targets."""

     _val_slides = [
@@ -195,10 +194,10 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_array = wsi.MultiWsiDataset.__getitem__(self, index)
         return functional.to_image(image_array)

eva/vision/data/datasets/classification/crc.py
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override

-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io


-class CRC(base.ImageClassification):
+class CRC(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for CRC images and corresponding targets."""

     _train_resource: structs.DownloadResource = structs.DownloadResource(
@@ -117,7 +116,7 @@ class CRC(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/gleason_arvaniti.py
@@ -12,12 +12,11 @@ from loguru import logger
 from torchvision import tv_tensors
 from typing_extensions import override

-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io


-class GleasonArvaniti(base.ImageClassification):
+class GleasonArvaniti(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for GleasonArvaniti images and corresponding targets."""

     _expected_dataset_lengths: Dict[str | None, int] = {
@@ -121,7 +120,7 @@ class GleasonArvaniti(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/mhist.py
@@ -7,12 +7,11 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override

-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io


-class MHIST(base.ImageClassification):
+class MHIST(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """MHIST dataset."""

     def __init__(
@@ -69,7 +68,7 @@ class MHIST(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_filename, _ = self._samples[index]
         image_path = os.path.join(self._dataset_path, image_filename)
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/panda.py
@@ -13,12 +13,11 @@ from torchvision.transforms.v2 import functional
 from typing_extensions import override

 from eva.core.data import splitting
-from eva.vision.data.datasets import _validators, structs, wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision, wsi
 from eva.vision.data.wsi.patching import samplers


-class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for PANDA images and corresponding targets."""

     _train_split_ratio: float = 0.7
@@ -121,10 +120,10 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_array = wsi.MultiWsiDataset.__getitem__(self, index)
         return functional.to_image(image_array)

eva/vision/data/datasets/classification/patch_camelyon.py
@@ -10,14 +10,13 @@ from torchvision.datasets import utils
 from torchvision.transforms.v2 import functional
 from typing_extensions import override

-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision

 _URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"
 """PatchCamelyon URL files templates."""


-class PatchCamelyon(base.ImageClassification):
+class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for PatchCamelyon images and corresponding targets."""

     _train_resources: List[structs.DownloadResource] = [
@@ -127,7 +126,7 @@ class PatchCamelyon(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         return self._load_from_h5("x", index)

     @override

eva/vision/data/datasets/classification/unitopatho.py
@@ -10,12 +10,11 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override

-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io


-class UniToPatho(base.ImageClassification):
+class UniToPatho(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for UniToPatho images and corresponding targets."""

     _expected_dataset_lengths: Dict[str | None, int] = {
@@ -109,7 +108,7 @@ class UniToPatho(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)

eva/vision/data/datasets/classification/wsi.py
@@ -9,12 +9,13 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override

-from eva.vision.data.datasets import wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import vision, wsi
 from eva.vision.data.wsi.patching import samplers


-class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+class WsiClassificationDataset(
+    wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]
+):
     """A general dataset class for whole-slide image classification using manifest files."""

     default_column_mapping: Dict[str, str] = {
@@ -78,10 +79,10 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)

     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         return wsi.MultiWsiDataset.__getitem__(self, index)

     @override