kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""UniToPatho dataset class."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import glob
|
|
5
|
+
import os
|
|
6
|
+
from typing import Callable, Dict, List, Literal
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
from torchvision import tv_tensors
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
|
|
13
|
+
from eva.vision.data.datasets import _validators, vision
|
|
14
|
+
from eva.vision.utils import io
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UniToPatho(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
    """UniToPatho classification dataset yielding images and integer class targets."""

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 6270,
        "val": 2399,
        None: 8669,
    }
    """Number of samples expected per split (`None` stands for the full dataset)."""

    _license: str = "CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/)"
    """License under which the dataset is distributed."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize the dataset.

        Samples listed in `train.csv` form the train split and the ones
        listed in `test.csv` form the validation split. The split is made
        by taking into account the patient IDs to avoid any data leakage.

        Args:
            root: Path to the root directory of the dataset.
            split: Dataset split to use. If `None`, the entire dataset is used.
            transforms: A function/transform which returns a transformed
                version of the raw data samples.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split

        # Populated by `configure()`; positions into `_image_files`.
        self._indices: List[int] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["HP", "NORM", "TA.HG", "TA.LG", "TVA.HG", "TVA.LG"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        names = ("HP", "NORM", "TA.HG", "TA.LG", "TVA.HG", "TVA.LG")
        return dict(zip(names, range(len(names))))

    @property
    def _dataset_path(self) -> str:
        """Directory (the `800` sub-folder of the root) holding images and manifests."""
        return os.path.join(self._root, "800")

    @functools.cached_property
    def _image_files(self) -> List[str]:
        """Collect all PNG files found recursively below the dataset path.

        Returns:
            Sorted list of image file paths.
        """
        pattern = os.path.join(self._dataset_path, "**/*.png")
        return sorted(glob.glob(pattern, recursive=True))

    @functools.cached_property
    def _manifest(self) -> pd.DataFrame:
        """Merge train.csv & test.csv into one dataframe indexed by `image_id`.

        A `split` column is added, set to "train" for rows of train.csv and
        to "val" for rows of test.csv.
        """
        frames = []
        for csv_name, split_name in (("train.csv", "train"), ("test.csv", "val")):
            frame = pd.read_csv(os.path.join(self._dataset_path, csv_name))
            frame["split"] = split_name
            frames.append(frame)
        return pd.concat(frames, axis=0).set_index("image_id")

    @override
    def filename(self, index: int) -> str:
        file_position = self._indices[index]
        return os.path.relpath(self._image_files[file_position], self._dataset_path)

    @override
    def prepare_data(self) -> None:
        _validators.check_dataset_exists(self._root, True)

    @override
    def configure(self) -> None:
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        expected_length = self._expected_dataset_lengths[self._split]
        _validators.check_dataset_integrity(
            self,
            length=expected_length,
            n_classes=6,
            first_and_last_labels=("HP", "TVA.LG"),
        )

    @override
    def load_data(self, index: int) -> tv_tensors.Image:
        file_position = self._indices[index]
        return io.read_image_as_tensor(self._image_files[file_position])

    @override
    def load_target(self, index: int) -> torch.Tensor:
        file_position = self._indices[index]
        label = self._extract_class(self._image_files[file_position])
        return torch.tensor(label, dtype=torch.long)

    @override
    def __len__(self) -> int:
        return len(self._indices)

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")

    def _extract_image_id(self, image_file: str) -> str:
        """Returns the base file name, which is used as the manifest `image_id`."""
        return os.path.basename(image_file)

    def _extract_class(self, file: str) -> int:
        """Looks up the `top_label` manifest entry of a file and returns it as int."""
        manifest_key = self._extract_image_id(file)
        return int(self._manifest.at[manifest_key, "top_label"])

    def _make_indices(self) -> List[int]:
        """Builds the dataset indices for the specified split.

        Raises:
            ValueError: If a manifest row carries a split other than
                "train" or "val".
        """
        train_positions: List[int] = []
        val_positions: List[int] = []

        for position, path in enumerate(self._image_files):
            assigned_split = self._manifest.at[self._extract_image_id(path), "split"]

            if assigned_split == "train":
                train_positions.append(position)
            elif assigned_split == "val":
                val_positions.append(position)
            else:
                raise ValueError(f"Invalid split value found: {assigned_split}")

        # The full dataset lists all train positions first, followed by val.
        positions_per_split = {
            "train": train_positions,
            "val": val_positions,
            None: train_positions + val_positions,
        }
        return positions_per_split[self._split]
|
|
@@ -9,12 +9,13 @@ import torch
|
|
|
9
9
|
from torchvision import tv_tensors
|
|
10
10
|
from typing_extensions import override
|
|
11
11
|
|
|
12
|
-
from eva.vision.data.datasets import wsi
|
|
13
|
-
from eva.vision.data.datasets.classification import base
|
|
12
|
+
from eva.vision.data.datasets import vision, wsi
|
|
14
13
|
from eva.vision.data.wsi.patching import samplers
|
|
15
14
|
|
|
16
15
|
|
|
17
|
-
class WsiClassificationDataset(
|
|
16
|
+
class WsiClassificationDataset(
|
|
17
|
+
wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]
|
|
18
|
+
):
|
|
18
19
|
"""A general dataset class for whole-slide image classification using manifest files."""
|
|
19
20
|
|
|
20
21
|
default_column_mapping: Dict[str, str] = {
|
|
@@ -78,10 +79,10 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
|
|
|
78
79
|
|
|
79
80
|
@override
|
|
80
81
|
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
|
|
81
|
-
return
|
|
82
|
+
return vision.VisionDataset.__getitem__(self, index)
|
|
82
83
|
|
|
83
84
|
@override
|
|
84
|
-
def
|
|
85
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
85
86
|
return wsi.MultiWsiDataset.__getitem__(self, index)
|
|
86
87
|
|
|
87
88
|
@override
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Segmentation datasets API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.data.datasets.segmentation.base import ImageSegmentation
|
|
4
3
|
from eva.vision.data.datasets.segmentation.bcss import BCSS
|
|
4
|
+
from eva.vision.data.datasets.segmentation.btcv import BTCV
|
|
5
5
|
from eva.vision.data.datasets.segmentation.consep import CoNSeP
|
|
6
6
|
from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
|
|
7
7
|
from eva.vision.data.datasets.segmentation.lits import LiTS
|
|
@@ -10,8 +10,8 @@ from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
|
|
|
10
10
|
from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
|
|
11
11
|
|
|
12
12
|
__all__ = [
|
|
13
|
-
"ImageSegmentation",
|
|
14
13
|
"BCSS",
|
|
14
|
+
"BTCV",
|
|
15
15
|
"CoNSeP",
|
|
16
16
|
"EmbeddingsSegmentationDataset",
|
|
17
17
|
"LiTS",
|
|
@@ -1,8 +1,12 @@
|
|
|
1
1
|
from typing import Any, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy.typing as npt
|
|
4
|
+
import torch
|
|
5
|
+
from torchvision import tv_tensors
|
|
4
6
|
|
|
7
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
5
8
|
from eva.vision.data.datasets import wsi
|
|
9
|
+
from eva.vision.utils import io
|
|
6
10
|
|
|
7
11
|
|
|
8
12
|
def get_coords_at_index(
|
|
@@ -36,3 +40,46 @@ def extract_mask_patch(
|
|
|
36
40
|
"""
|
|
37
41
|
(x, y), width, height = get_coords_at_index(dataset, index)
|
|
38
42
|
return mask[y : y + height, x : x + width]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_volume_tensor(file: str, orientation: str = "PLS") -> eva_tv_tensors.Volume:
    """Read a NIfTI file and wrap it as :class:`eva.vision.data.tv_tensors.Volume`.

    Args:
        file: The path to the NIfTI file.
        orientation: The orientation code to reorient the nifti image.

    Returns:
        Volume tensor of shape `[T, C, H, W]`.

    Raises:
        ValueError: If the loaded image has no affine matrix.
    """
    nifti_image = io.read_nifti(file, orientation=orientation)
    voxels = io.nifti_to_array(nifti_image)

    # Prepend a singleton channel axis, then move the last array axis to the front.
    voxels_with_channel = voxels[None, :, :, :]
    voxels_tchw = voxels_with_channel.transpose(3, 0, 1, 2)

    if nifti_image.affine is None:
        raise ValueError(f"Affine matrix is missing for {file}.")

    # NOTE(review): the column permutation presumably mirrors the axis
    # reordering applied to the voxel array above - confirm.
    column_order = [2, 0, 1, 3]
    affine = torch.tensor(nifti_image.affine[:, column_order], dtype=torch.float32)

    return eva_tv_tensors.Volume(
        voxels_tchw, affine=affine, dtype=torch.float32
    )  # type: ignore
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def load_mask_tensor(
    file: str, volume_file: str | None = None, orientation: str = "PLS"
) -> tv_tensors.Mask:
    """Load a mask from NIfTI file as :class:`torchvision.tv_tensors.Mask`.

    Args:
        file: The path to the NIfTI file containing the mask.
        volume_file: The path to the volume file used as orientation reference in case
            the mask file is missing the pixdim array in the NIfTI header.
        orientation: The orientation code to reorient the nifti image.

    Returns:
        Mask tensor of shape `[T, C, H, W]`.
    """
    # Forward the caller's orientation; it used to be hard-coded to "PLS",
    # which silently ignored the `orientation` argument.
    nii = io.read_nifti(file, orientation=orientation, orientation_reference=volume_file)
    array = io.nifti_to_array(nii)
    # Prepend a singleton channel axis, then move the last array axis to the front.
    array_reshaped_tchw = array[None, :, :, :].transpose(3, 0, 1, 2)
    return tv_tensors.Mask(array_reshaped_tchw, dtype=torch.long)  # type: ignore
|
|
@@ -12,13 +12,13 @@ from torchvision import tv_tensors
|
|
|
12
12
|
from torchvision.transforms.v2 import functional
|
|
13
13
|
from typing_extensions import override
|
|
14
14
|
|
|
15
|
-
from eva.vision.data.datasets import _validators, wsi
|
|
16
|
-
from eva.vision.data.datasets.segmentation import _utils
|
|
15
|
+
from eva.vision.data.datasets import _validators, vision, wsi
|
|
16
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
17
17
|
from eva.vision.data.wsi.patching import samplers
|
|
18
18
|
from eva.vision.utils import io
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
class BCSS(wsi.MultiWsiDataset,
|
|
21
|
+
class BCSS(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
22
22
|
"""Dataset class for BCSS semantic segmentation task.
|
|
23
23
|
|
|
24
24
|
Source: https://github.com/PathologyDataScience/BCSS
|
|
@@ -71,7 +71,6 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
71
71
|
width: Width of the patches to be extracted, in pixels.
|
|
72
72
|
height: Height of the patches to be extracted, in pixels.
|
|
73
73
|
target_mpp: Target microns per pixel (mpp) for the patches.
|
|
74
|
-
backend: The backend to use for reading the whole-slide images.
|
|
75
74
|
transforms: Transforms to apply to the extracted image & mask patches.
|
|
76
75
|
"""
|
|
77
76
|
self._split = split
|
|
@@ -90,7 +89,7 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
90
89
|
overwrite_mpp=0.25,
|
|
91
90
|
backend="pil",
|
|
92
91
|
)
|
|
93
|
-
|
|
92
|
+
vision.VisionDataset.__init__(self, transforms=transforms)
|
|
94
93
|
|
|
95
94
|
@property
|
|
96
95
|
@override
|
|
@@ -129,15 +128,15 @@ class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
129
128
|
|
|
130
129
|
@override
|
|
131
130
|
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
|
|
132
|
-
return
|
|
131
|
+
return vision.VisionDataset.__getitem__(self, index)
|
|
133
132
|
|
|
134
133
|
@override
|
|
135
|
-
def
|
|
134
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
136
135
|
image_array = wsi.MultiWsiDataset.__getitem__(self, index)
|
|
137
136
|
return functional.to_image(image_array)
|
|
138
137
|
|
|
139
138
|
@override
|
|
140
|
-
def
|
|
139
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
141
140
|
path = self._get_mask_path(index)
|
|
142
141
|
mask = io.read_image_as_array(path)
|
|
143
142
|
mask_patch = _utils.extract_mask_patch(mask, self, index)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""BTCV dataset."""
|
|
2
|
+
|
|
3
|
+
import glob
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
from typing import Any, Callable, Dict, List, Literal, Tuple
|
|
7
|
+
|
|
8
|
+
import huggingface_hub
|
|
9
|
+
from torchvision import tv_tensors
|
|
10
|
+
from torchvision.datasets import utils as data_utils
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
|
|
13
|
+
from eva.vision.data import tv_tensors as eva_tv_tensors
|
|
14
|
+
from eva.vision.data.datasets import _utils as _data_utils
|
|
15
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
16
|
+
from eva.vision.data.datasets.vision import VisionDataset
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BTCV(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
    """Beyond the Cranial Vault (BTCV) Abdomen dataset.

    The BTCV dataset comprises abdominal CT acquired at the Vanderbilt
    University Medical Center from metastatic liver cancer patients or
    post-operative ventral hernia patients. The dataset contains one
    background class and thirteen organ classes.

    More info:
      - Multi-organ Abdominal CT Reference Standard Segmentations
        https://zenodo.org/records/1169361
      - Dataset Split
        https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/BTCV/dataset/dataset_0.json
    """

    _split_index_ranges = {
        "train": [(0, 24)],
        "val": [(24, 30)],
        None: [(0, 30)],
    }
    """Sample indices for the dataset splits."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        download: bool = False,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: Path to the dataset root directory.
            split: Dataset split to use ('train' or 'val').
                If None, it uses the full dataset.
            download: Whether to download the dataset.
            transforms: A callable object for applying data transformations.
                If None, no transformations are applied.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split
        self._download = download

        # Start empty so `len(dataset)` works before `configure()` has run;
        # both lists are filled in by `configure()`.
        self._samples: List[Tuple[str, str]] = []
        self._indices: List[int] = []

    @property
    @override
    def classes(self) -> List[str]:
        return [
            "background",
            "spleen",
            "right_kidney",
            "left_kidney",
            "gallbladder",
            "esophagus",
            "liver",
            "stomach",
            "aorta",
            "inferior_vena_cava",
            "portal_and_splenic_vein",
            "pancreas",
            "right_adrenal_gland",
            "left_adrenal_gland",
        ]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        return os.path.basename(self._samples[self._indices[index]][0])

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples = self._find_samples()
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        """Checks that every sample of the selected split exists on disk.

        Raises:
            OSError: If fewer samples were found than the split requires, or
                if a volume or segmentation file of the split is missing.
        """
        if len(self._samples) < len(self._indices):
            raise OSError(f"Dataset is missing {len(self._indices) - len(self._samples)} files.")

        # Report the samples that were actually checked, i.e. the ones the
        # split indices point to, not the first `len(self)` samples.
        split_samples = [self._samples[sample_index] for sample_index in self._indices]
        invalid_samples = [
            (volume_file, segmentation_file)
            for volume_file, segmentation_file in split_samples
            if not (os.path.isfile(volume_file) and os.path.isfile(segmentation_file))
        ]
        if invalid_samples:
            raise OSError(
                f"Dataset '{self.__class__.__qualname__}' contains missing or "
                f"corrupted samples ({len(invalid_samples)} in total). "
                f"Examples of missing folders: {str(invalid_samples[:10])[:-1]}, ...]. "
            )

    @override
    def __getitem__(
        self, index: int
    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
        volume = self.load_data(index)
        mask = self.load_target(index)
        metadata = self.load_metadata(index) or {}
        volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
        return volume_tensor, mask_tensor, metadata

    @override
    def __len__(self) -> int:
        return len(self._indices)

    @override
    def load_data(self, index: int) -> eva_tv_tensors.Volume:
        """Loads the CT volume for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the CT volume of shape `[T, C, H, W]`.
        """
        ct_scan_file, _ = self._samples[self._indices[index]]
        return _utils.load_volume_tensor(ct_scan_file)

    @override
    def load_target(self, index: int) -> tv_tensors.Mask:
        """Loads the segmentation mask for a given sample.

        Args:
            index: The index of the desired sample.

        Returns:
            Tensor representing the segmentation mask of shape `[T, C, H, W]`.
        """
        ct_scan_file, mask_file = self._samples[self._indices[index]]
        # The CT scan is passed along as the orientation reference for the mask.
        return _utils.load_mask_tensor(mask_file, ct_scan_file)

    def _apply_transforms(
        self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
        """Applies transformations to the provided data.

        Args:
            ct_scan: The CT volume tensor.
            mask: The segmentation mask tensor.

        Returns:
            A tuple containing the transformed CT and mask tensors.
        """
        return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)

    def _find_samples(self) -> list[tuple[str, str]]:
        """Retrieves the file paths for the CT volumes and segmentations.

        Volumes (`imagesTr`) and segmentations (`labelsTr`) are paired through
        the last number that appears in their file path; only ids present in
        both folders are kept.

        Returns:
            A list of (volume path, segmentation path) tuples sorted by id.
        """

        def filename_id(filename: str) -> int:
            # The greedy prefix makes the capture group match the last run of
            # digits in the path.
            matches = re.match(r".*(?:\D|^)(\d+)", filename)
            if matches is None:
                raise ValueError(f"Filename '{filename}' is not valid.")

            return int(matches.group(1))

        # Use the "BTCV" sub-folder when it exists, otherwise the root itself.
        subdir = os.path.join(self._root, "BTCV")
        root = subdir if os.path.isdir(subdir) else self._root

        volume_files_pattern = os.path.join(root, "imagesTr", "*.nii.gz")
        volume_filenames = glob.glob(volume_files_pattern)
        volume_ids = {filename_id(filename): filename for filename in volume_filenames}

        segmentation_files_pattern = os.path.join(root, "labelsTr", "*.nii.gz")
        segmentation_filenames = glob.glob(segmentation_files_pattern)
        segmentation_ids = {filename_id(filename): filename for filename in segmentation_filenames}

        return [
            (volume_ids[file_id], segmentation_ids[file_id])
            for file_id in sorted(volume_ids.keys() & segmentation_ids.keys())
        ]

    def _make_indices(self) -> list[int]:
        """Builds the dataset indices for the specified split.

        Raises:
            ValueError: If the configured split is not 'train', 'val' or `None`.
        """
        index_ranges = self._split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")

        return _data_utils.ranges_to_indices(index_ranges)

    def _download_dataset(self) -> None:
        """Downloads `BTCV.zip` from the Hugging Face hub and extracts it into the root.

        Raises:
            ValueError: If the `HF_TOKEN` environment variable is not set.
            FileNotFoundError: If the archive is not present after the download.
        """
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("Huggingface token required, please set the HF_TOKEN env variable.")

        huggingface_hub.snapshot_download(
            "Luffy503/VoCo_Downstream",
            repo_type="dataset",
            token=hf_token,
            local_dir=self._root,
            ignore_patterns=[".git*"],
            allow_patterns=["BTCV.zip"],
        )

        zip_path = os.path.join(self._root, "BTCV.zip")
        if not os.path.exists(zip_path):
            raise FileNotFoundError(
                f"BTCV.zip not found in {self._root}, something with the download went wrong."
            )

        # `remove_finished=True` is passed so the archive is removed after extraction.
        data_utils.extract_archive(zip_path, self._root, remove_finished=True)
|
|
@@ -11,13 +11,13 @@ from torchvision import tv_tensors
|
|
|
11
11
|
from torchvision.transforms.v2 import functional
|
|
12
12
|
from typing_extensions import override
|
|
13
13
|
|
|
14
|
-
from eva.vision.data.datasets import _validators, wsi
|
|
15
|
-
from eva.vision.data.datasets.segmentation import _utils
|
|
14
|
+
from eva.vision.data.datasets import _validators, vision, wsi
|
|
15
|
+
from eva.vision.data.datasets.segmentation import _utils
|
|
16
16
|
from eva.vision.data.wsi.patching import samplers
|
|
17
17
|
from eva.vision.utils import io
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class CoNSeP(wsi.MultiWsiDataset,
|
|
20
|
+
class CoNSeP(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
21
21
|
"""Dataset class for CoNSeP semantic segmentation task.
|
|
22
22
|
|
|
23
23
|
As in [1], we combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
|
|
@@ -55,7 +55,6 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
55
55
|
width: Width of the patches to be extracted, in pixels.
|
|
56
56
|
height: Height of the patches to be extracted, in pixels.
|
|
57
57
|
target_mpp: Target microns per pixel (mpp) for the patches.
|
|
58
|
-
backend: The backend to use for reading the whole-slide images.
|
|
59
58
|
transforms: Transforms to apply to the extracted image & mask patches.
|
|
60
59
|
"""
|
|
61
60
|
self._split = split
|
|
@@ -112,15 +111,15 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
|
112
111
|
|
|
113
112
|
@override
|
|
114
113
|
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
|
|
115
|
-
return
|
|
114
|
+
return vision.VisionDataset.__getitem__(self, index)
|
|
116
115
|
|
|
117
116
|
@override
|
|
118
|
-
def
|
|
117
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
119
118
|
image_array = wsi.MultiWsiDataset.__getitem__(self, index)
|
|
120
119
|
return functional.to_image(image_array)
|
|
121
120
|
|
|
122
121
|
@override
|
|
123
|
-
def
|
|
122
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
124
123
|
path = self._get_mask_path(index)
|
|
125
124
|
mask = np.array(io.read_mat(path)["type_map"])
|
|
126
125
|
mask_patch = _utils.extract_mask_patch(mask, self, index)
|
|
@@ -14,7 +14,7 @@ class EmbeddingsSegmentationDataset(embeddings_base.EmbeddingsDataset[tv_tensors
|
|
|
14
14
|
"""Embeddings segmentation dataset."""
|
|
15
15
|
|
|
16
16
|
@override
|
|
17
|
-
def
|
|
17
|
+
def load_embeddings(self, index: int) -> List[torch.Tensor]:
|
|
18
18
|
filename = self.filename(index)
|
|
19
19
|
embeddings_path = os.path.join(self._root, filename)
|
|
20
20
|
embeddings = torch.load(embeddings_path, map_location="cpu")
|
|
@@ -23,7 +23,7 @@ class EmbeddingsSegmentationDataset(embeddings_base.EmbeddingsDataset[tv_tensors
|
|
|
23
23
|
return [tensor.squeeze(0) for tensor in embeddings]
|
|
24
24
|
|
|
25
25
|
@override
|
|
26
|
-
def
|
|
26
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
27
27
|
filename = self._data.at[index, self._column_mapping["target"]]
|
|
28
28
|
mask_path = os.path.join(self._root, filename)
|
|
29
29
|
semantic_labels = torch.load(mask_path, map_location="cpu")
|
|
@@ -13,12 +13,11 @@ from typing_extensions import override
|
|
|
13
13
|
|
|
14
14
|
from eva.core import utils
|
|
15
15
|
from eva.core.data import splitting
|
|
16
|
-
from eva.vision.data.datasets import _validators
|
|
17
|
-
from eva.vision.data.datasets.segmentation import base
|
|
16
|
+
from eva.vision.data.datasets import _validators, vision
|
|
18
17
|
from eva.vision.utils import io
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
class LiTS(
|
|
20
|
+
class LiTS(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
|
|
22
21
|
"""LiTS - Liver Tumor Segmentation Challenge.
|
|
23
22
|
|
|
24
23
|
Webpage: https://competitions.codalab.org/competitions/17094
|
|
@@ -110,21 +109,23 @@ class LiTS(base.ImageSegmentation):
|
|
|
110
109
|
)
|
|
111
110
|
|
|
112
111
|
@override
|
|
113
|
-
def
|
|
112
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
114
113
|
sample_index, slice_index = self._indices[index]
|
|
115
114
|
volume_path = self._volume_files[sample_index]
|
|
116
|
-
|
|
115
|
+
image_nii = io.read_nifti(volume_path, slice_index)
|
|
116
|
+
image_array = io.nifti_to_array(image_nii)
|
|
117
117
|
if self._fix_orientation:
|
|
118
118
|
image_array = self._orientation(image_array, sample_index)
|
|
119
119
|
return tv_tensors.Image(image_array.transpose(2, 0, 1))
|
|
120
120
|
|
|
121
121
|
@override
|
|
122
|
-
def
|
|
122
|
+
def load_target(self, index: int) -> tv_tensors.Mask:
|
|
123
123
|
sample_index, slice_index = self._indices[index]
|
|
124
124
|
segmentation_path = self._segmentation_file(sample_index)
|
|
125
|
-
|
|
125
|
+
mask_nii = io.read_nifti(segmentation_path, slice_index)
|
|
126
|
+
mask_array = io.nifti_to_array(mask_nii)
|
|
126
127
|
if self._fix_orientation:
|
|
127
|
-
semantic_labels = self._orientation(
|
|
128
|
+
semantic_labels = self._orientation(mask_array, sample_index)
|
|
128
129
|
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
|
|
129
130
|
|
|
130
131
|
def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
|
|
@@ -64,7 +64,8 @@ class LiTSBalanced(lits.LiTS):
|
|
|
64
64
|
if sample_idx not in split_indices:
|
|
65
65
|
continue
|
|
66
66
|
|
|
67
|
-
|
|
67
|
+
segmentation_nii = io.read_nifti(self._segmentation_file(sample_idx))
|
|
68
|
+
segmentation = io.nifti_to_array(segmentation_nii)
|
|
68
69
|
tumor_filter = segmentation == 2
|
|
69
70
|
tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
|
|
70
71
|
|