kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/classification/wsi.py
@@ -0,0 +1,105 @@
+"""WSI classification dataset."""
+
+import os
+from typing import Any, Callable, Dict, Literal, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data.datasets import wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+    """A general dataset class for whole-slide image classification using manifest files."""
+
+    default_column_mapping: Dict[str, str] = {
+        "path": "path",
+        "target": "target",
+        "split": "split",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        width: int,
+        height: int,
+        target_mpp: float,
+        sampler: samplers.Sampler,
+        backend: str = "openslide",
+        split: Literal["train", "val", "test"] | None = None,
+        image_transforms: Callable | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+    ):
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, relative to
+                the `root` argument. The `path` column is expected to contain
+                relative paths to the whole-slide images.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            split: The split of the dataset to load.
+            image_transforms: Transforms to apply to the extracted image patches.
+            column_mapping: Mapping of the columns in the manifest file.
+        """
+        self._split = split
+        self._column_mapping = self.default_column_mapping | column_mapping
+        self._manifest = self._load_manifest(os.path.join(root, manifest_file))
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._manifest[self._column_mapping["path"]].tolist(),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @override
+    def filename(self, index: int) -> str:
+        path = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["path"]]
+        return os.path.basename(path) if os.path.isabs(path) else path
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        return wsi.MultiWsiDataset.__getitem__(self, index)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        target = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["target"]]
+        return np.asarray(target)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
+        df = pd.read_csv(manifest_path)
+
+        missing_columns = set(self._column_mapping.values()) - set(df.columns)
+        if self._split is None:
+            missing_columns = missing_columns - {self._column_mapping["split"]}
+        if missing_columns:
+            raise ValueError(f"Missing columns in the manifest file: {missing_columns}")
+
+        if self._split is not None:
+            df = df.loc[df[self._column_mapping["split"]] == self._split]
+
+        return df.reset_index(drop=True)
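For orientation, a minimal usage sketch of the new manifest-based dataset above. This is not part of the release: the root path and manifest.csv are hypothetical, and GridSampler with max_samples is taken from the sampler modules added in this version.

from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
from eva.vision.data.wsi.patching import samplers

# Assumes /data/slides/manifest.csv with "path", "target" and "split" columns,
# where "path" holds slide paths relative to the root directory.
dataset = WsiClassificationDataset(
    root="/data/slides",
    manifest_file="manifest.csv",
    width=224,
    height=224,
    target_mpp=0.5,
    sampler=samplers.GridSampler(max_samples=100),
    split="train",
)
# Depending on the surrounding data-module integration, the prepare_data()/setup()
# hooks may need to run before indexing.
image, target, metadata = dataset[0]  # image patch, label tensor, {"wsi_id": ...}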
eva/vision/data/datasets/segmentation/__init__.py
@@ -1,6 +1,19 @@
 """Segmentation datasets API."""
 
 from eva.vision.data.datasets.segmentation.base import ImageSegmentation
-from eva.vision.data.datasets.segmentation.
+from eva.vision.data.datasets.segmentation.bcss import BCSS
+from eva.vision.data.datasets.segmentation.consep import CoNSeP
+from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
+from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
+from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
-__all__ = [
+__all__ = [
+    "ImageSegmentation",
+    "BCSS",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "LiTS",
+    "MoNuSAC",
+    "TotalSegmentator2D",
+]
eva/vision/data/datasets/segmentation/_utils.py
@@ -0,0 +1,38 @@
+from typing import Any, Tuple
+
+import numpy.typing as npt
+
+from eva.vision.data.datasets import wsi
+
+
+def get_coords_at_index(
+    dataset: wsi.MultiWsiDataset, index: int
+) -> Tuple[Tuple[int, int], int, int]:
+    """Returns the coordinates ((x,y),width,height) of the patch at the given index.
+
+    Args:
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    image_index = dataset._get_dataset_idx(index)
+    patch_index = index if image_index == 0 else index - dataset.cumulative_sizes[image_index - 1]
+    wsi_dataset = dataset.datasets[image_index]
+    if isinstance(wsi_dataset, wsi.WsiDataset):
+        coords = wsi_dataset._coords
+        return coords.x_y[patch_index], coords.width, coords.height
+    else:
+        raise Exception(f"Expected WsiDataset, got {type(wsi_dataset)}")
+
+
+def extract_mask_patch(
+    mask: npt.NDArray[Any], dataset: wsi.MultiWsiDataset, index: int
+) -> npt.NDArray[Any]:
+    """Reads the mask patch at the coordinates corresponding to the dataset index.
+
+    Args:
+        mask: The mask array.
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    (x, y), width, height = get_coords_at_index(dataset, index)
+    return mask[y : y + height, x : x + width]
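The patch-to-mask alignment above comes down to a plain NumPy slice; a standalone toy check of that arithmetic (all numbers made up):

import numpy as np

# extract_mask_patch slices mask[y : y + height, x : x + width]: rows index
# the y-axis and columns the x-axis, so the result has shape (height, width).
mask = np.arange(36).reshape(6, 6)
(x, y), width, height = (2, 1), 3, 2
patch = mask[y : y + height, x : x + width]
assert patch.shape == (height, width)  # (2, 3)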
eva/vision/data/datasets/segmentation/base.py
@@ -12,10 +12,7 @@ from eva.vision.data.datasets import vision
 class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
     """Image segmentation abstract dataset."""
 
-    def __init__(
-        self,
-        transforms: Callable | None = None,
-    ) -> None:
+    def __init__(self, transforms: Callable | None = None) -> None:
         """Initializes the image segmentation base class.
 
         Args:
@@ -34,17 +31,6 @@ class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns a mapping of the class name to its target index."""
 
-    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
-        """Returns the dataset metadata.
-
-        Args:
-            index: The index of the data sample to return the metadata of.
-                If `None`, it will return the metadata of the current dataset.
-
-        Returns:
-            The sample metadata.
-        """
-
     @abc.abstractmethod
     def load_image(self, index: int) -> tv_tensors.Image:
         """Loads and returns the `index`'th image sample.
@@ -68,16 +54,29 @@ class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.
             values which represent the pixel class id.
         """
 
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample to return the metadata of.
+                If `None`, it will return the metadata of the current dataset.
+
+        Returns:
+            The sample metadata.
+        """
+
     @abc.abstractmethod
     @override
     def __len__(self) -> int:
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
         image = self.load_image(index)
         mask = self.load_mask(index)
-
+        metadata = self.load_metadata(index) or {}
+        image_tensor, mask_tensor = self._apply_transforms(image, mask)
+        return image_tensor, mask_tensor, metadata
 
     def _apply_transforms(
         self, image: tv_tensors.Image, mask: tv_tensors.Mask
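The practical upshot: every ImageSegmentation subclass now yields a 3-tuple (image, mask, metadata). A hypothetical, self-contained subclass illustrating the new contract; it assumes the base's default transforms (None) pass the tensors through unchanged, as the signature above suggests:

import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.vision.data.datasets.segmentation import base


class ToySegmentation(base.ImageSegmentation):
    """Illustrative two-sample dataset; not part of the release."""

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.zeros(3, 8, 8))

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        return tv_tensors.Mask(torch.zeros(8, 8, dtype=torch.int64))

    @override
    def load_metadata(self, index: int):
        return {"sample": index}

    @override
    def __len__(self) -> int:
        return 2


image, mask, metadata = ToySegmentation()[0]  # metadata == {"sample": 0}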
eva/vision/data/datasets/segmentation/bcss.py
@@ -0,0 +1,236 @@
+"""BCSS dataset."""
+
+import glob
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+import numpy.typing as npt
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, wsi
+from eva.vision.data.datasets.segmentation import _utils, base
+from eva.vision.data.wsi.patching import samplers
+from eva.vision.utils import io
+
+
+class BCSS(wsi.MultiWsiDataset, base.ImageSegmentation):
+    """Dataset class for BCSS semantic segmentation task.
+
+    Source: https://github.com/PathologyDataScience/BCSS
+
+    We apply the class grouping proposed by the challenge baseline:
+    https://bcsegmentation.grand-challenge.org/Baseline/
+
+    outside_roi: outside_roi
+    tumor: angioinvasion, dcis
+    stroma: stroma
+    inflammatory: lymphocytic_infiltrate, plasma_cells, other_immune_infiltrate
+    necrosis: necrosis_or_debris
+    other: remaining
+
+    Be aware that outside_roi should be assigned zero-weight during model training.
+    """
+
+    _train_split_ratio: float = 0.8
+    """Train split ratio."""
+
+    _val_split_ratio: float = 0.2
+    """Validation split ratio."""
+
+    _expected_length: int = 151
+    """Expected dataset length."""
+
+    _val_institutes = {"BH", "C8", "A8", "A1", "E9"}
+    """Medical institutes to use for the validation split."""
+
+    _test_institutes = {"OL", "LL", "E2", "EW", "GM", "S3"}
+    """Medical institutes to use for the test split."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "trainval", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+                If `None`, it will use the ::class::`GridSampler` sampler.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            transforms: Transforms to apply to the extracted image & mask patches.
+        """
+        self._split = split
+        self._root = root
+
+        self.datasets: List[wsi.WsiDataset]  # type: ignore
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler or samplers.GridSampler(max_samples=1000),
+            target_mpp=target_mpp,
+            overwrite_mpp=0.25,
+            backend="pil",
+        )
+        base.ImageSegmentation.__init__(self, transforms=transforms)
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return list(self.class_to_idx.keys())
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {
+            "outside_roi": 0,
+            "tumor": 1,
+            "stroma": 2,
+            "inflammatory": 3,
+            "necrosis": 4,
+            "other": 5,
+        }
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, True)
+
+        if not os.path.isdir(os.path.join(self._root, "masks")):
+            raise FileNotFoundError(f"'masks' directory not found in {self._root}.")
+        if not os.path.isdir(os.path.join(self._root, "rgbs_colorNormalized")):
+            raise FileNotFoundError(f"'rgbs_colorNormalized' directory not found in {self._root}.")
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=6,
+            first_and_last_labels=((self.classes[0], self.classes[-1])),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
+        return base.ImageSegmentation.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        path = self._get_mask_path(index)
+        mask = io.read_image_as_array(path)
+        mask_patch = _utils.extract_mask_patch(mask, self, index)
+        mask_patch = self._map_classes(mask_patch)
+        return tv_tensors.Mask(mask_patch, dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        (x, y), width, height = _utils.get_coords_at_index(self, index)
+        return {"coords": f"{x},{y},{width},{height}"}
+
+    def _load_file_paths(
+        self, split: Literal["train", "val", "trainval", "test"] | None = None
+    ) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        file_paths = sorted(glob.glob(os.path.join(self._root, "rgbs_colorNormalized/*.png")))
+        if len(file_paths) != self._expected_length:
+            raise ValueError(
+                f"Expected {self._expected_length} images, found {len(file_paths)} in {self._root}."
+            )
+
+        train_indices, val_indices, test_indices = [], [], []
+        for i, path in enumerate(file_paths):
+            institute = Path(path).stem.split("-")[1]
+            if institute in self._test_institutes:
+                test_indices.append(i)
+            elif institute in self._val_institutes:
+                val_indices.append(i)
+            else:
+                train_indices.append(i)
+
+        match split:
+            case "train":
+                return [file_paths[i] for i in train_indices]
+            case "val":
+                return [file_paths[i] for i in val_indices]
+            case "trainval":
+                return [file_paths[i] for i in train_indices + val_indices]
+            case "test":
+                return [file_paths[i] for i in test_indices]
+            case None:
+                return file_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
+
+    def _get_mask_path(self, index):
+        """Returns the path to the mask file corresponding to the patch at the given index."""
+        return os.path.join(self._root, "masks", self.filename(index))
+
+    def _map_classes(self, array: npt.NDArray[Any]) -> npt.NDArray[Any]:
+        """Maps the classes of the mask array to the grouped tissue type classes."""
+        original_to_grouped_class_mapping = {
+            "outside_roi": "outside_roi",
+            "angioinvasion": "tumor",
+            "dcis": "tumor",
+            "stroma": "stroma",
+            "lymphocytic_infiltrate": "inflammatory",
+            "plasma_cells": "inflammatory",
+            "other_immune_infiltrate": "inflammatory",
+            "necrosis_or_debris": "necrosis",
+        }
+
+        mapped_array = np.full_like(array, fill_value=self.class_to_idx["other"], dtype=int)
+
+        for original_class, grouped_class in original_to_grouped_class_mapping.items():
+            original_class_idx = _original_class_to_idx[original_class]
+            grouped_class_idx = self.class_to_idx[grouped_class]
+            mapped_array[array == original_class_idx] = grouped_class_idx
+
+        return mapped_array
+
+
+_original_class_to_idx = {
+    "outside_roi": 0,
+    "tumor": 1,
+    "stroma": 2,
+    "lymphocytic_infiltrate": 3,
+    "necrosis_or_debris": 4,
+    "glandular_secretions": 5,
+    "blood": 6,
+    "exclude": 7,
+    "metaplasia_NOS": 8,
+    "fat": 9,
+    "plasma_cells": 10,
+    "other_immune_infiltrate": 11,
+    "mucoid_material": 12,
+    "normal_acinus_or_duct": 13,
+    "lymphatics": 14,
+    "undetermined": 15,
+    "nerve": 16,
+    "skin_adnexa": 17,
+    "blood_vessel": 18,
+    "angioinvasion": 19,
+    "dcis": 20,
+    "other": 21,
+}
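To make the grouping concrete, here is a standalone replay of the _map_classes logic on a toy label array (indices copied from the two mappings above; the array values are made up):

import numpy as np

# Original labels 3 (lymphocytic_infiltrate) and 10 (plasma_cells) both
# collapse to grouped index 3 (inflammatory); labels absent from the mapping,
# e.g. 6 (blood), fall back to "other" (5) via the np.full_like default.
array = np.array([[0, 3], [10, 6]])
mapped = np.full_like(array, fill_value=5, dtype=int)
for original_idx, grouped_idx in {0: 0, 19: 1, 20: 1, 2: 2, 3: 3, 10: 3, 11: 3, 4: 4}.items():
    mapped[array == original_idx] = grouped_idx
assert mapped.tolist() == [[0, 3], [3, 5]]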
eva/vision/data/datasets/segmentation/consep.py
@@ -0,0 +1,156 @@
+"""CoNSeP dataset."""
+
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+import numpy.typing as npt
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, wsi
+from eva.vision.data.datasets.segmentation import _utils, base
+from eva.vision.data.wsi.patching import samplers
+from eva.vision.utils import io
+
+
+class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
+    """Dataset class for CoNSeP semantic segmentation task.
+
+    We combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
+    into the epithelial class and 5 (fibroblast), 6 (muscle) & 7 (endothelial) into
+    the spindle-shaped class.
+    """
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 27,
+        "val": 14,
+        None: 41,
+    }
+    """Expected dataset lengths for the splits and complete dataset."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler | None = None,
+        split: Literal["train", "val"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.25,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+                If `None`, it will use the ::class::`ForegroundGridSampler` sampler.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            transforms: Transforms to apply to the extracted image & mask patches.
+        """
+        self._split = split
+        self._root = root
+
+        self.datasets: List[wsi.WsiDataset]  # type: ignore
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler or samplers.ForegroundGridSampler(max_samples=25),
+            target_mpp=target_mpp,
+            overwrite_mpp=0.25,
+            backend="pil",
+            image_transforms=transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return [
+            "background",
+            "other",
+            "inflammatory",
+            "epithelial",
+            "spindle-shaped",
+        ]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {label: index for index, label in enumerate(self.classes)}
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, True)
+
+        if not os.path.isdir(os.path.join(self._root, "Train")):
+            raise FileNotFoundError(f"Train directory not found in {self._root}.")
+        if not os.path.isdir(os.path.join(self._root, "Test")):
+            raise FileNotFoundError(f"Test directory not found in {self._root}.")
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=5,
+            first_and_last_labels=((self.classes[0], self.classes[-1])),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
+        return base.ImageSegmentation.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        path = self._get_mask_path(index)
+        mask = np.array(io.read_mat(path)["type_map"])
+        mask_patch = _utils.extract_mask_patch(mask, self, index)
+        mask_patch = self._map_classes(mask_patch)
+        mask_tensor = tv_tensors.Mask(mask_patch, dtype=torch.int64)  # type: ignore[reportCallIssue]
+        return self._image_transforms(mask_tensor) if self._image_transforms else mask_tensor
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        (x, y), width, height = _utils.get_coords_at_index(self, index)
+        return {"coords": f"{x},{y},{width},{height}"}
+
+    def _load_file_paths(self, split: Literal["train", "val"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        paths = list(glob.glob(os.path.join(self._root, "**/Images/*.png"), recursive=True))
+        n_expected = self._expected_dataset_lengths[None]
+        if len(paths) != n_expected:
+            raise ValueError(f"Expected {n_expected} images, found {len(paths)} in {self._root}.")
+
+        if split is not None:
+            split_to_folder = {"train": "Train", "val": "Test"}
+            paths = filter(lambda p: split_to_folder[split] == p.split("/")[-3], paths)
+
+        return sorted(paths)
+
+    def _get_mask_path(self, index: int) -> str:
+        """Returns the path to the mask file corresponding to the patch at the given index."""
+        filename = self.filename(index).split(".")[0]
+        mask_dir = "Train" if filename.startswith("train") else "Test"
+        return os.path.join(self._root, mask_dir, "Labels", f"{filename}.mat")
+
+    def _map_classes(self, array: npt.NDArray[Any]) -> npt.NDArray[Any]:
+        """Summarizes classes 3 & 4, and 5, 6."""
+        array = np.where(array == 4, 3, array)
+        array = np.where(array > 4, 4, array)
+        return array
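The class merging described in the docstring reduces CoNSeP's seven nucleus types to four; the two np.where calls do this in sequence. A standalone check with made-up labels:

import numpy as np

# 4 (dysplastic/malignant epithelial) folds into 3 (epithelial), then
# 5/6/7 (fibroblast, muscle, endothelial) collapse to 4 (spindle-shaped).
labels = np.array([0, 1, 2, 3, 4, 5, 6, 7])
labels = np.where(labels == 4, 3, labels)
labels = np.where(labels > 4, 4, labels)
assert labels.tolist() == [0, 1, 2, 3, 3, 4, 4, 4]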
eva/vision/data/datasets/segmentation/embeddings.py
@@ -0,0 +1,34 @@
+"""Embeddings based semantic segmentation dataset."""
+
+import os
+from typing import List
+
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.core.data.datasets import embeddings as embeddings_base
+
+
+class EmbeddingsSegmentationDataset(embeddings_base.EmbeddingsDataset[tv_tensors.Mask]):
+    """Embeddings segmentation dataset."""
+
+    @override
+    def _load_embeddings(self, index: int) -> List[torch.Tensor]:
+        filename = self.filename(index)
+        embeddings_path = os.path.join(self._root, filename)
+        embeddings = torch.load(embeddings_path, map_location="cpu")
+        if isinstance(embeddings, torch.Tensor):
+            embeddings = [embeddings]
+        return [tensor.squeeze(0) for tensor in embeddings]
+
+    @override
+    def _load_target(self, index: int) -> tv_tensors.Mask:
+        filename = self._data.at[index, self._column_mapping["target"]]
+        mask_path = os.path.join(self._root, filename)
+        semantic_labels = torch.load(mask_path, map_location="cpu")
+        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    @override
+    def __len__(self) -> int:
+        return len(self._data)