kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/models/modules/head.py +4 -2
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/_recorder.py +4 -1
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +2 -2
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +3 -4
- eva/vision/data/datasets/classification/breakhis.py +3 -4
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +3 -4
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +32 -19
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0
eva/core/data/datasets/base.py
CHANGED
@@ -1,6 +1,7 @@
 """Base dataset class."""
 
 import abc
+from typing import Generic, TypeVar
 
 from eva.core.data.datasets import dataset
 
@@ -55,11 +56,15 @@ class Dataset(dataset.TorchDataset):
     """
 
 
-class MapDataset(Dataset, abc.ABC):
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class MapDataset(Dataset, abc.ABC, Generic[DataSample]):
     """Abstract base class for all map-style datasets."""
 
     @abc.abstractmethod
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> DataSample:
         """Retrieves the item at the given index.
 
         Args:
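For context, the new `Generic[DataSample]` parameter lets a subclass declare its sample type once, so `__getitem__` becomes statically typed. A minimal sketch of a typed subclass, assuming only what the diff above shows (the `__len__` method is included for DataLoader compatibility and is an assumption, not part of this diff):

from typing import Tuple

import torch

from eva.core.data.datasets import base


class PairDataset(base.MapDataset[Tuple[torch.Tensor, torch.Tensor]]):
    """Toy map-style dataset whose DataSample type is a (data, target) pair."""

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Type checkers now verify this return type against the declared DataSample.
        return torch.rand(3), torch.tensor(index % 2)

    def __len__(self) -> int:
        return 10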
eva/core/models/modules/head.py
CHANGED
@@ -1,6 +1,6 @@
 """Neural Network Head Module."""
 
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -108,7 +108,9 @@ class HeadModule(module.ModelModule):
         return self._batch_step(batch)
 
     @override
-    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def predict_step(
+        self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | List[torch.Tensor]:
         tensor = INPUT_BATCH(*batch).data
         return tensor if self.backbone is None else self.backbone(tensor)
 
eva/core/models/modules/typings.py
CHANGED
@@ -1,6 +1,6 @@
 """Type annotations for model modules."""
 
-from typing import Any, Dict, NamedTuple
+from typing import Any, Dict, List, NamedTuple
 
 import lightning.pytorch as pl
 import torch
@@ -13,7 +13,7 @@ MODEL_TYPE = nn.Module | pl.LightningModule
 class INPUT_BATCH(NamedTuple):
     """The default input batch data scheme."""
 
-    data: torch.Tensor
+    data: torch.Tensor | List[torch.Tensor]
     """The data batch."""
 
     targets: torch.Tensor | None = None
eva/core/models/transforms/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 """Model outputs transforms API."""
 
+from eva.core.models.transforms.as_discrete import AsDiscrete
 from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
 from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures
 
-__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
+__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]
eva/core/models/transforms/as_discrete.py
ADDED
@@ -0,0 +1,57 @@
+"""Defines the AsDiscrete transformation."""
+
+import torch
+
+
+class AsDiscrete:
+    """Convert the logits tensor to discrete values."""
+
+    def __init__(
+        self,
+        argmax: bool = False,
+        to_onehot: int | bool | None = None,
+        threshold: float | None = None,
+    ) -> None:
+        """Convert the input tensor/array into discrete values.
+
+        Args:
+            argmax: Whether to execute argmax function on input data before transform.
+            to_onehot: if not None, convert input data into the one-hot format with
+                specified number of classes. If bool, it will try to infer the number
+                of classes.
+            threshold: If not None, threshold the float values to int number 0 or 1
+                with specified threshold.
+        """
+        super().__init__()
+
+        self._argmax = argmax
+        self._to_onehot = to_onehot
+        self._threshold = threshold
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation."""
+        if self._argmax:
+            tensor = torch.argmax(tensor, dim=1, keepdim=True)
+
+        if self._to_onehot is not None:
+            tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)
+
+        if self._threshold is not None:
+            tensor = tensor >= self._threshold
+
+        return tensor
+
+
+def _one_hot(
+    tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1
+) -> torch.Tensor:
+    """Convert input tensor into one-hot format (implementation taken from MONAI)."""
+    shape = list(tensor.shape)
+    if shape[dim] != 1:
+        raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.")
+
+    shape[dim] = num_classes
+    o = torch.zeros(size=shape, dtype=dtype, device=tensor.device)
+    tensor = o.scatter_(dim=dim, index=tensor.long(), value=1)
+
+    return tensor
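A usage sketch of the new transform, importing it via the updated `eva.core.models.transforms` API shown above; the tensor shapes are illustrative:

import torch

from eva.core.models.transforms import AsDiscrete

logits = torch.randn(2, 3, 4, 4)  # (batch, classes, height, width)

# Argmax over the class dim, then expand back to one-hot with 3 classes.
transform = AsDiscrete(argmax=True, to_onehot=3)
discrete = transform(logits)

print(discrete.shape)                # torch.Size([2, 3, 4, 4])
print(discrete.sum(dim=1).unique())  # tensor([1]): exactly one class per pixel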
eva/core/models/wrappers/_utils.py
CHANGED
@@ -1,8 +1,17 @@
 """Utilities and helper functions for models."""
 
+import hashlib
+import os
+import sys
+from typing import Any, Dict
+
+import torch
+from fsspec.core import url_to_fs
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
-from torch import nn
+from torch import hub, nn
+
+from eva.core.utils.progress_bar import tqdm
 
 
 def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
@@ -23,3 +32,114 @@ def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
     model.load_state_dict(checkpoint, strict=True)
 
     logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
+
+
+def load_state_dict_from_url(
+    url: str,
+    *,
+    model_dir: str | None = None,
+    filename: str | None = None,
+    progress: bool = True,
+    md5: str | None = None,
+    force: bool = False,
+) -> Dict[str, Any]:
+    """Loads the Torch serialized object at the given URL.
+
+    If the object is already present and valid in `model_dir`, it's
+    deserialized and returned.
+
+    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
+    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+    Args:
+        url: URL of the object to download.
+        model_dir: Directory in which to save the object.
+        filename: Name for the downloaded file. Filename from ``url`` will be used if not set.
+        progress: Whether or not to display a progress bar to stderr.
+        md5: MD5 file code to check whether the file is valid. If not, it will re-download it.
+        force: Whether to download the file regardless if it exists.
+    """
+    model_dir = model_dir or os.path.join(hub.get_dir(), "checkpoints")
+    os.makedirs(model_dir, exist_ok=True)
+
+    cached_file = os.path.join(model_dir, filename or os.path.basename(url))
+    if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
+        sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
+        _download_url_to_file(url, cached_file, progress=progress)
+        if md5 is None or not _check_integrity(cached_file, md5):
+            sys.stderr.write(f"File MD5: {_calculate_md5(cached_file)}\n")
+
+    return torch.load(cached_file, map_location="cpu")
+
+
+def _download_url_to_file(
+    url: str,
+    dst: str,
+    *,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    try:
+        _download_with_fsspec(url=url, dst=dst, progress=progress)
+    except Exception:
+        try:
+            hub.download_url_to_file(url=url, dst=dst, progress=progress)
+        except Exception as hub_e:
+            raise RuntimeError(
+                f"Failed to download file from {url} using both fsspec and hub."
+            ) from hub_e
+
+
+def _download_with_fsspec(
+    url: str,
+    dst: str,
+    *,
+    chunk_size: int = 1024 * 1024,
+    progress: bool = True,
+) -> None:
+    """Download object at the given URL to a local path using fsspec.
+
+    Args:
+        url: URL of the object to download.
+        dst: Full path where object will be saved.
+        chunk_size: The size of each chunk to read in bytes.
+        progress: Whether or not to display a progress bar to stderr.
+    """
+    filesystem, _ = url_to_fs(url, anon=False)
+    total_size_bytes = filesystem.size(url)
+    with (
+        filesystem.open(url, "rb") as remote_file,
+        tqdm(
+            total=total_size_bytes,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+            disable=not progress,
+        ) as pbar,
+    ):
+        with open(dst, "wb") as local_file:
+            while True:
+                data = remote_file.read(chunk_size)
+                if not data:
+                    break
+
+                local_file.write(data)
+                pbar.update(chunk_size)
+
+
+def _calculate_md5(path: str) -> str:
+    """Calculate the md5 hash of a file."""
+    with open(path, "rb") as file:
+        return hashlib.md5(file.read(), usedforsecurity=False).hexdigest()
+
+
+def _check_integrity(path: str, md5: str | None) -> bool:
+    """Check if the file matches the specified md5 hash."""
+    return (md5 is None) or (md5 == _calculate_md5(path))
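A usage sketch of the new helper; the checkpoint URL is hypothetical, and since the function lives in a private module, the import path below is an assumption:

from torch import nn

from eva.core.models.wrappers._utils import load_state_dict_from_url

model = nn.Linear(8, 2)  # stand-in for a real backbone
state_dict = load_state_dict_from_url(
    "https://example.com/checkpoints/backbone.pth",  # hypothetical URL
    progress=True,  # streams via fsspec with a tqdm bar, falls back to torch.hub
    md5=None,       # pass an MD5 hex digest to have mismatching files re-downloaded
)
model.load_state_dict(state_dict)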
eva/core/trainers/_recorder.py
CHANGED
@@ -129,7 +129,10 @@ class SessionRecorder:
     def _save_config(self) -> None:
         """Saves the config yaml with resolved env placeholders to the output directory."""
         if self.config_path:
-            config = OmegaConf.load(self.config_path)
+            config_fs = cloud_io.get_filesystem(self.config_path)
+            with config_fs.open(self.config_path, "r") as config_file:
+                config = OmegaConf.load(config_file)  # type: ignore
+
             fs = cloud_io.get_filesystem(self._output_dir, anon=False)
             with fs.open(os.path.join(self._output_dir, self._config_file), "w") as file:
                 config_yaml = OmegaConf.to_yaml(config, resolve=True)
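The fix resolves a filesystem from the path itself before handing an open file to OmegaConf, so `config_path` may also point at remote storage. A minimal sketch of the pattern (the s3 path is hypothetical):

from lightning_fabric.utilities import cloud_io
from omegaconf import OmegaConf

config_path = "s3://my-bucket/configs/run.yaml"  # hypothetical remote path
fs = cloud_io.get_filesystem(config_path)
with fs.open(config_path, "r") as config_file:
    config = OmegaConf.load(config_file)  # OmegaConf accepts file-like objects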
eva/core/utils/suppress_logs.py
ADDED
@@ -0,0 +1,28 @@
+"""Context manager to temporarily suppress all logging outputs."""
+
+import logging
+import sys
+from types import TracebackType
+from typing import Type
+
+
+class SuppressLogs:
+    """Context manager to suppress all logs but print exceptions if they occur."""
+
+    def __enter__(self) -> None:
+        """Temporarily increase log level to suppress all logs."""
+        self._logger = logging.getLogger()
+        self._previous_level = self._logger.level
+        self._logger.setLevel(logging.CRITICAL + 1)
+
+    def __exit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool:
+        """Restores the previous logging level and print exceptions."""
+        self._logger.setLevel(self._previous_level)
+        if exc_value:
+            print(f"Error: {exc_value}", file=sys.stderr)
+        return False
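A usage sketch of the new context manager, using only the standard-library `logging` module:

import logging

from eva.core.utils.suppress_logs import SuppressLogs

logging.basicConfig(level=logging.INFO)

with SuppressLogs():
    logging.info("swallowed")  # root level is raised above CRITICAL inside the block

logging.info("visible again")  # __exit__ restores the previous level

# Exceptions raised inside the block are echoed to stderr and then re-raised,
# since __exit__ returns False.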
eva/vision/data/dataloaders/collate_fn/collection.py
ADDED
@@ -0,0 +1,22 @@
+"""Data only collate filter function."""
+
+from typing import Any, List
+
+import torch
+
+from eva.core.models.modules.typings import INPUT_BATCH
+
+
+def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any:
+    """Collate function for stacking a collection of data samples.
+
+    Args:
+        batch: The batch to be collated.
+
+    Returns:
+        The collated batch.
+    """
+    tensors, targets, metadata = zip(*batch, strict=False)
+    batch_tensors = torch.cat(list(map(torch.stack, tensors)))
+    batch_targets = torch.cat(list(map(torch.stack, targets)))
+    return batch_tensors, batch_targets, metadata
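A usage sketch of the new collate function. It assumes each sample is a `(data, targets, metadata)` triplet in the `INPUT_BATCH` layout whose first two fields are per-crop tensor collections, and that `collection_collate` is re-exported from the new `collate_fn` package; both are assumptions for illustration:

import torch
from torch.utils.data import DataLoader

from eva.vision.data.dataloaders.collate_fn import collection_collate

# Hypothetical samples: four crops (and four targets) drawn from each of six items.
samples = [
    (
        [torch.rand(3, 8, 8) for _ in range(4)],  # data: one tensor per crop
        [torch.tensor(i % 2) for _ in range(4)],  # targets: one label per crop
        {"id": i},                                # sample-level metadata
    )
    for i in range(6)
]

loader = DataLoader(samples, batch_size=2, collate_fn=collection_collate)
data, targets, metadata = next(iter(loader))
print(data.shape)     # torch.Size([8, 3, 8, 8]): 2 samples x 4 crops, stacked then concatenated
print(targets.shape)  # torch.Size([8])
print(metadata)       # ({'id': 0}, {'id': 1}): metadata is passed through unstacked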
eva/vision/data/datasets/__init__.py
CHANGED
@@ -16,9 +16,9 @@ from eva.vision.data.datasets.classification import (
 )
 from eva.vision.data.datasets.segmentation import (
     BCSS,
+    BTCV,
     CoNSeP,
     EmbeddingsSegmentationDataset,
-    ImageSegmentation,
     LiTS,
     LiTSBalanced,
     MoNuSAC,
@@ -29,6 +29,7 @@ from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
 
 __all__ = [
     "BACH",
+    "BTCV",
     "BCSS",
     "BreaKHis",
     "BRACS",
@@ -43,7 +44,6 @@ __all__ = [
     "WsiClassificationDataset",
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
-    "ImageSegmentation",
     "LiTS",
     "LiTSBalanced",
     "MoNuSAC",
eva/vision/data/datasets/classification/bach.py
CHANGED
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override
 
-from eva.vision.data.datasets import _utils, _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _utils, _validators, structs, vision
 from eva.vision.utils import io
 
 
-class BACH(base.ImageClassification):
+class BACH(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BACH images and corresponding targets."""
 
     _train_index_ranges: List[Tuple[int, int]] = [
@@ -125,7 +124,7 @@ class BACH(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[self._indices[index]]
         return io.read_image_as_tensor(image_path)
 
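The same refactor repeats across the classification datasets below: the deleted `base.ImageClassification` parent is replaced by the generic `vision.VisionDataset[tv_tensors.Image, torch.Tensor]`, and `load_image` becomes `load_data`. A minimal sketch of the new pattern; `load_target` and `__len__` are assumptions carried over from the old API, as their signatures are not shown in this diff:

import torch
from torchvision import tv_tensors

from eva.vision.data.datasets import vision
from eva.vision.utils import io


class ToyPatches(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
    """Hypothetical classification dataset over a fixed list of image paths."""

    def __init__(self, paths: list[str], labels: list[int]) -> None:
        super().__init__()
        self._paths, self._labels = paths, labels

    def load_data(self, index: int) -> tv_tensors.Image:
        return io.read_image_as_tensor(self._paths[index])

    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._labels[index])

    def __len__(self) -> int:
        return len(self._paths)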
eva/vision/data/datasets/classification/bracs.py
CHANGED
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io
 
 
-class BRACS(base.ImageClassification):
+class BRACS(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BRACS images and corresponding targets."""
 
     _expected_dataset_lengths: Dict[str, int] = {
@@ -80,7 +79,7 @@ class BRACS(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
         return io.read_image_as_tensor(image_path)
 
eva/vision/data/datasets/classification/breakhis.py
CHANGED
@@ -10,12 +10,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io
 
 
-class BreaKHis(base.ImageClassification):
+class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for BreaKHis images and corresponding targets."""
 
     _resources: List[structs.DownloadResource] = [
@@ -145,7 +144,7 @@ class BreaKHis(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)
 
eva/vision/data/datasets/classification/camelyon16.py
CHANGED
@@ -11,12 +11,11 @@ from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators, wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision, wsi
 from eva.vision.data.wsi.patching import samplers
 
 
-class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
+class Camelyon16(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for Camelyon16 images and corresponding targets."""
 
     _val_slides = [
@@ -195,10 +194,10 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
 
     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_array = wsi.MultiWsiDataset.__getitem__(self, index)
         return functional.to_image(image_array)
 
eva/vision/data/datasets/classification/crc.py
CHANGED
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io
 
 
-class CRC(base.ImageClassification):
+class CRC(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for CRC images and corresponding targets."""
 
     _train_resource: structs.DownloadResource = structs.DownloadResource(
@@ -117,7 +116,7 @@ class CRC(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
         return io.read_image_as_tensor(image_path)
 
eva/vision/data/datasets/classification/gleason_arvaniti.py
CHANGED
@@ -12,12 +12,11 @@ from loguru import logger
 from torchvision import tv_tensors
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io
 
 
-class GleasonArvaniti(base.ImageClassification):
+class GleasonArvaniti(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for GleasonArvaniti images and corresponding targets."""
 
     _expected_dataset_lengths: Dict[str | None, int] = {
@@ -121,7 +120,7 @@ class GleasonArvaniti(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)
 
eva/vision/data/datasets/classification/mhist.py
CHANGED
@@ -7,12 +7,11 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io
 
 
-class MHIST(base.ImageClassification):
+class MHIST(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """MHIST dataset."""
 
     def __init__(
@@ -69,7 +68,7 @@ class MHIST(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_filename, _ = self._samples[index]
         image_path = os.path.join(self._dataset_path, image_filename)
         return io.read_image_as_tensor(image_path)
eva/vision/data/datasets/classification/panda.py
CHANGED
@@ -13,12 +13,11 @@ from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
 from eva.core.data import splitting
-from eva.vision.data.datasets import _validators, structs, wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision, wsi
 from eva.vision.data.wsi.patching import samplers
 
 
-class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for PANDA images and corresponding targets."""
 
     _train_split_ratio: float = 0.7
@@ -121,10 +120,10 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
 
     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_array = wsi.MultiWsiDataset.__getitem__(self, index)
         return functional.to_image(image_array)
 
eva/vision/data/datasets/classification/patch_camelyon.py
CHANGED
@@ -10,14 +10,13 @@ from torchvision.datasets import utils
 from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, structs, vision
 
 _URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"
 """PatchCamelyon URL files templates."""
 
 
-class PatchCamelyon(base.ImageClassification):
+class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for PatchCamelyon images and corresponding targets."""
 
     _train_resources: List[structs.DownloadResource] = [
@@ -127,7 +126,7 @@ class PatchCamelyon(base.ImageClassification):
     )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         return self._load_from_h5("x", index)
 
     @override
eva/vision/data/datasets/classification/unitopatho.py
CHANGED
@@ -10,12 +10,11 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
-from eva.vision.data.datasets import _validators
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import _validators, vision
 from eva.vision.utils import io
 
 
-class UniToPatho(base.ImageClassification):
+class UniToPatho(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     """Dataset class for UniToPatho images and corresponding targets."""
 
     _expected_dataset_lengths: Dict[str | None, int] = {
@@ -109,7 +108,7 @@ class UniToPatho(base.ImageClassification):
    )
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[self._indices[index]]
         return io.read_image_as_tensor(image_path)
 
eva/vision/data/datasets/classification/wsi.py
CHANGED
@@ -9,12 +9,13 @@ import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
-from eva.vision.data.datasets import wsi
-from eva.vision.data.datasets.classification import base
+from eva.vision.data.datasets import vision, wsi
 from eva.vision.data.wsi.patching import samplers
 
 
-class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+class WsiClassificationDataset(
+    wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]
+):
     """A general dataset class for whole-slide image classification using manifest files."""
 
     default_column_mapping: Dict[str, str] = {
@@ -78,10 +79,10 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
 
     @override
     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-        return base.ImageClassification.__getitem__(self, index)
+        return vision.VisionDataset.__getitem__(self, index)
 
     @override
-    def load_image(self, index: int) -> tv_tensors.Image:
+    def load_data(self, index: int) -> tv_tensors.Image:
         return wsi.MultiWsiDataset.__getitem__(self, index)
 
     @override