kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/classification/bach.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple

-import numpy as np
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override

@@ -52,8 +53,7 @@ class BACH(base.ImageClassification):
         root: str,
         split: Literal["train", "val"] | None = None,
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initialize the dataset.

@@ -68,15 +68,10 @@ class BACH(base.ImageClassification):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does
                 not yet exist on disk.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)

         self._root = root
         self._split = split
@@ -130,14 +125,14 @@ class BACH(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[self._indices[index]]
-        return io.read_image(image_path)
+        return io.read_image_as_tensor(image_path)

     @override
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         _, target = self._samples[self._indices[index]]
-        return
+        return torch.tensor(target, dtype=torch.long)

     @override
     def __len__(self) -> int:
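
The BACH changes above (mirrored in CRC and MHIST below) replace the separate image_transforms/target_transforms arguments with a single transforms callable that receives and returns the (image, target) pair. A minimal usage sketch, assuming a torchvision v2 pipeline; the data path and transform choices are illustrative, not taken from the diff:

import torch
from torchvision.transforms import v2

from eva.vision.data.datasets.classification import bach

# A v2 pipeline is itself a callable over (image, target), so it can be
# passed directly as the new single `transforms` argument.
transforms = v2.Compose([
    v2.Resize(size=(224, 224)),
    v2.ToDtype(torch.float32, scale=True),
])

dataset = bach.BACH(root="./data/bach", split="train", transforms=transforms)
# Depending on the eva dataset lifecycle, prepare_data()/setup() may be
# required before indexing.
image, target, metadata = dataset[0]  # __getitem__ now also yields metadata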
eva/vision/data/datasets/classification/base.py
@@ -3,32 +3,29 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple

-import numpy as np
+import torch
+from torchvision import tv_tensors
 from typing_extensions import override

 from eva.vision.data.datasets import vision


-class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
     """Image classification abstract dataset."""

     def __init__(
         self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the image classification dataset.

         Args:
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
         super().__init__()

-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
+        self._transforms = transforms

     @property
     def classes(self) -> List[str] | None:
@@ -38,19 +35,18 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns a mapping of the class name to its target index."""

-    def load_metadata(self, index: int
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
         """Returns the dataset metadata.

         Args:
             index: The index of the data sample to return the metadata of.
-                If `None`, it will return the metadata of the current dataset.

         Returns:
             The sample metadata.
         """

     @abc.abstractmethod
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         """Returns the `index`'th image sample.

         Args:
@@ -61,7 +57,7 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         """

     @abc.abstractmethod
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         """Returns the `index`'th target sample.

         Args:
@@ -77,14 +73,15 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         raise NotImplementedError

     @override
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
         image = self.load_image(index)
         target = self.load_target(index)
-        return self._apply_transforms(image, target)
+        image, target = self._apply_transforms(image, target)
+        return image, target, self.load_metadata(index) or {}

     def _apply_transforms(
-        self, image: np.ndarray, target: np.ndarray
-    ) -> Tuple[np.ndarray, np.ndarray]:
+        self, image: tv_tensors.Image, target: torch.Tensor
+    ) -> Tuple[tv_tensors.Image, torch.Tensor]:
         """Applies the transforms to the provided data and returns them.

         Args:
@@ -94,10 +91,6 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         Returns:
             A tuple with the image and the target transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
+        if self._transforms is not None:
+            image, target = self._transforms(image, target)
         return image, target
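
With the reworked base class, the subclass contract is: load_image returns a tv_tensors.Image, load_target returns a torch.Tensor, and __getitem__ yields an (image, target, metadata) triple. A minimal sketch of a conforming subclass; the class name and sample list are hypothetical, and the eva.vision.utils.io import path is an assumption based on the file list above:

from typing import Any, Dict, List, Tuple

import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.vision.data.datasets.classification import base
from eva.vision.utils import io  # assumed import path for the io helpers


class ToyClassification(base.ImageClassification):
    """Two-sample toy dataset following the new tensor-based contract."""

    _samples: List[Tuple[str, int]] = [("a.png", 0), ("b.png", 1)]

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        image_path, _ = self._samples[index]
        return io.read_image_as_tensor(image_path)

    @override
    def load_target(self, index: int) -> torch.Tensor:
        _, target = self._samples[index]
        return torch.tensor(target, dtype=torch.long)

    @override
    def load_metadata(self, index: int) -> Dict[str, Any]:
        return {"path": self._samples[index][0]}

    @override
    def __len__(self) -> int:
        return len(self._samples)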
eva/vision/data/datasets/classification/camelyon16.py (new file)
@@ -0,0 +1,244 @@
+"""Camelyon16 dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
+    """Dataset class for Camelyon16 images and corresponding targets."""
+
+    _val_slides = [
+        "normal_010",
+        "normal_013",
+        "normal_016",
+        "normal_017",
+        "normal_019",
+        "normal_020",
+        "normal_025",
+        "normal_030",
+        "normal_031",
+        "normal_032",
+        "normal_052",
+        "normal_056",
+        "normal_057",
+        "normal_067",
+        "normal_076",
+        "normal_079",
+        "normal_085",
+        "normal_095",
+        "normal_098",
+        "normal_099",
+        "normal_101",
+        "normal_102",
+        "normal_105",
+        "normal_106",
+        "normal_109",
+        "normal_129",
+        "normal_132",
+        "normal_137",
+        "normal_142",
+        "normal_143",
+        "normal_148",
+        "normal_152",
+        "tumor_001",
+        "tumor_005",
+        "tumor_011",
+        "tumor_012",
+        "tumor_013",
+        "tumor_019",
+        "tumor_031",
+        "tumor_037",
+        "tumor_043",
+        "tumor_046",
+        "tumor_057",
+        "tumor_065",
+        "tumor_069",
+        "tumor_071",
+        "tumor_073",
+        "tumor_079",
+        "tumor_080",
+        "tumor_081",
+        "tumor_082",
+        "tumor_085",
+        "tumor_097",
+        "tumor_109",
+    ]
+    """Validation slide names, same as the ones in patch camelyon."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+            seed: Random seed for reproducibility.
+        """
+        self._split = split
+        self._root = root
+        self._width = width
+        self._height = height
+        self._target_mpp = target_mpp
+        self._seed = seed
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["normal", "tumor"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"normal": 0, "tumor": 1}
+
+    @functools.cached_property
+    def annotations_test_set(self) -> Dict[str, str]:
+        """Loads the dataset labels."""
+        path = os.path.join(self._root, "testing/reference.csv")
+        reference_df = pd.read_csv(path, header=None)
+        return {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+
+    @functools.cached_property
+    def annotations(self) -> Dict[str, str]:
+        """Loads the dataset labels."""
+        annotations = {}
+        if self._split in ["test", None]:
+            path = os.path.join(self._root, "testing/reference.csv")
+            reference_df = pd.read_csv(path, header=None)
+            annotations.update(
+                {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+            )
+
+        if self._split in ["train", "val", None]:
+            annotations.update(
+                {
+                    self._get_id_from_path(file_path): self._get_class_from_path(file_path)
+                    for file_path in self._file_paths
+                    if "test" not in file_path
+                }
+            )
+        return annotations
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, False)
+
+        expected_directories = ["training/normal", "training/tumor", "testing/images"]
+        for resource in expected_directories:
+            if not os.path.isdir(os.path.join(self._root, resource)):
+                raise FileNotFoundError(f"'{resource}' not found in the root folder.")
+
+        if not os.path.isfile(os.path.join(self._root, "testing/reference.csv")):
+            raise FileNotFoundError("'reference.csv' file not found in the testing folder.")
+
+    @override
+    def validate(self) -> None:
+
+        expected_n_files = {
+            "train": 216,
+            "val": 54,
+            "test": 129,
+            None: 399,
+        }
+        _validators.check_number_of_files(
+            self._file_paths, expected_n_files[self._split], self._split
+        )
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=2,
+            first_and_last_labels=("normal", "tumor"),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        file_path = self._file_paths[self._get_dataset_idx(index)]
+        class_name = self.annotations[self._get_id_from_path(file_path)]
+        return torch.tensor(self.class_to_idx[class_name], dtype=torch.int64)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        train_paths, val_paths = [], []
+        for path in glob.glob(os.path.join(self._root, "training/**/*.tif")):
+            if self._get_id_from_path(path) in self._val_slides:
+                val_paths.append(path)
+            else:
+                train_paths.append(path)
+        test_paths = glob.glob(os.path.join(self._root, "testing/images", "*.tif"))
+
+        match split:
+            case "train":
+                paths = train_paths
+            case "val":
+                paths = val_paths
+            case "test":
+                paths = test_paths
+            case None:
+                paths = train_paths + val_paths + test_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
+        return sorted([os.path.relpath(path, self._root) for path in paths])
+
+    def _get_id_from_path(self, file_path: str) -> str:
+        """Extracts the slide ID from the file path."""
+        return os.path.basename(file_path).replace(".tif", "")
+
+    def _get_class_from_path(self, file_path: str) -> str:
+        """Extracts the class name from the file path."""
+        class_name = self._get_id_from_path(file_path).split("_")[0]
+        if class_name not in self.classes:
+            raise ValueError(f"Invalid class name '{class_name}' in file path '{file_path}'.")
+        return class_name
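
The new Camelyon16 class composes wsi.MultiWsiDataset (patch extraction) with base.ImageClassification (labels and metadata), so each item is a patch tensor plus a slide-level label. A hedged instantiation sketch — GridSampler and its max_samples argument are assumed from eva/vision/data/wsi/patching/samplers/grid.py in the file list and are not shown in this diff:

from eva.vision.data.datasets.classification import camelyon16
from eva.vision.data.wsi.patching import samplers

dataset = camelyon16.Camelyon16(
    root="./data/camelyon16",
    sampler=samplers.GridSampler(max_samples=100),  # assumed sampler API
    split="train",
    width=224,
    height=224,
    target_mpp=0.5,
    backend="openslide",  # requires openslide to be installed
)
dataset.prepare_data()  # checks the expected directory layout (see diff above)
# Depending on the eva dataset lifecycle, setup() may also be required here.
patch, target, metadata = dataset[0]
print(metadata["wsi_id"], int(target))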
eva/vision/data/datasets/classification/crc.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple

-import numpy as np
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override

@@ -37,8 +38,7 @@ class CRC(base.ImageClassification):
         root: str,
         split: Literal["train", "val"],
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the dataset.

@@ -56,15 +56,10 @@ class CRC(base.ImageClassification):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does
                 not yet exist on disk.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)

         self._root = root
         self._split = split
@@ -122,14 +117,14 @@ class CRC(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
-        return io.read_image(image_path)
+        return io.read_image_as_tensor(image_path)

     @override
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         _, target = self._samples[index]
-        return
+        return torch.tensor(target, dtype=torch.long)

     @override
     def __len__(self) -> int:
eva/vision/data/datasets/classification/mhist.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple

-import numpy as np
+import torch
+from torchvision import tv_tensors
 from typing_extensions import override

 from eva.vision.data.datasets import _validators
@@ -18,23 +19,17 @@ class MHIST(base.ImageClassification):
         self,
         root: str,
         split: Literal["train", "test"],
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initialize the dataset.

         Args:
             root: Path to the root directory of the dataset.
             split: Dataset split to use.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)

         self._root = root
         self._split = split
@@ -74,16 +69,16 @@ class MHIST(base.ImageClassification):
         )

     @override
-    def load_image(self, index: int) -> np.ndarray:
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_filename, _ = self._samples[index]
         image_path = os.path.join(self._dataset_path, image_filename)
-        return io.read_image(image_path)
+        return io.read_image_as_tensor(image_path)

     @override
-    def load_target(self, index: int) -> np.ndarray:
+    def load_target(self, index: int) -> torch.Tensor:
         _, label = self._samples[index]
         target = self.class_to_idx[label]
-        return
+        return torch.tensor(target, dtype=torch.float32)

     @override
     def __len__(self) -> int:
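
Note the dtype asymmetry in the new targets: BACH and CRC return torch.long class indices for multiclass heads, while MHIST returns torch.float32, presumably so the binary target can be fed straight into logit-based binary losses and metrics. A self-contained illustration of why the float target matters:

import torch

logit = torch.tensor([0.3])                    # raw model output for one patch
target = torch.tensor(1, dtype=torch.float32)  # shaped like MHIST.load_target's output
loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, target.view(1))
print(round(loss.item(), 4))

# With an integer (long) target the same call raises a dtype error:
# binary_cross_entropy_with_logits expects floating-point targets.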