kaiko-eva 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of kaiko-eva might be problematic; see the registry advisory for details.
- eva/core/callbacks/writers/embeddings/base.py +3 -4
- eva/core/data/dataloaders/dataloader.py +2 -2
- eva/core/data/splitting/random.py +6 -5
- eva/core/data/splitting/stratified.py +12 -6
- eva/core/losses/__init__.py +5 -0
- eva/core/losses/cross_entropy.py +27 -0
- eva/core/metrics/__init__.py +0 -4
- eva/core/metrics/defaults/__init__.py +0 -2
- eva/core/models/modules/module.py +9 -9
- eva/core/models/transforms/extract_cls_features.py +17 -9
- eva/core/models/transforms/extract_patch_features.py +23 -11
- eva/core/utils/progress_bar.py +15 -0
- eva/vision/data/datasets/__init__.py +4 -0
- eva/vision/data/datasets/classification/__init__.py +2 -1
- eva/vision/data/datasets/classification/camelyon16.py +4 -1
- eva/vision/data/datasets/classification/panda.py +17 -1
- eva/vision/data/datasets/classification/wsi.py +4 -1
- eva/vision/data/datasets/segmentation/__init__.py +2 -0
- eva/vision/data/datasets/segmentation/consep.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +49 -29
- eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
- eva/vision/data/datasets/segmentation/monusac.py +7 -7
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +2 -2
- eva/vision/data/datasets/wsi.py +37 -1
- eva/vision/data/wsi/patching/coordinates.py +9 -1
- eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
- eva/vision/data/wsi/patching/samplers/random.py +4 -2
- eva/vision/losses/__init__.py +2 -2
- eva/vision/losses/dice.py +75 -8
- eva/vision/metrics/__init__.py +11 -0
- eva/vision/metrics/defaults/__init__.py +7 -0
- eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
- eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
- eva/vision/metrics/segmentation/BUILD +1 -0
- eva/vision/metrics/segmentation/__init__.py +9 -0
- eva/vision/metrics/segmentation/_utils.py +69 -0
- eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
- eva/vision/metrics/segmentation/mean_iou.py +57 -0
- eva/vision/models/modules/semantic_segmentation.py +4 -3
- eva/vision/models/networks/backbones/_utils.py +12 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
- eva/vision/models/networks/backbones/pathology/histai.py +8 -2
- eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
- eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
- eva/vision/models/networks/backbones/pathology/paige.py +51 -0
- eva/vision/models/networks/decoders/__init__.py +1 -1
- eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
- eva/vision/models/networks/decoders/segmentation/base.py +16 -0
- eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
- eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
- eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
- eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
- eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
- eva/vision/utils/io/__init__.py +7 -1
- eva/vision/utils/io/nifti.py +19 -4
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/METADATA +3 -34
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/RECORD +61 -48
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/WHEEL +1 -1
- eva/core/metrics/mean_iou.py +0 -120
- eva/vision/models/networks/decoders/decoder.py +0 -7
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.3.dist-info}/licenses/LICENSE +0 -0

eva/core/callbacks/writers/embeddings/base.py
CHANGED

@@ -172,15 +172,14 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):

     def _check_if_exists(self) -> None:
         """Checks if the output directory already exists and if it should be overwritten."""
-        try:
-            …
-        except FileExistsError as e:
+        os.makedirs(self._output_dir, exist_ok=True)
+        if os.path.exists(os.path.join(self._output_dir, "manifest.csv")) and not self._overwrite:
             raise FileExistsError(
                 f"The embeddings output directory already exists: {self._output_dir}. This "
                 "either means that they have been computed before or that a wrong output "
                 "directory is being used. Consider using `eva fit` instead, selecting a "
                 "different output directory or setting overwrite=True."
-            ) from e
+            )
         os.makedirs(self._output_dir, exist_ok=True)
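The rewritten check treats an existing `manifest.csv` (rather than the directory itself) as the signal that embeddings were already computed. A minimal standalone sketch of the new logic, with illustrative names:

    import os

    def check_if_exists(output_dir: str, overwrite: bool) -> None:
        # The directory is now created up front; only a pre-existing manifest
        # from a previous run triggers the error when overwrite is disabled.
        os.makedirs(output_dir, exist_ok=True)
        if os.path.exists(os.path.join(output_dir, "manifest.csv")) and not overwrite:
            raise FileExistsError(f"The embeddings output directory already exists: {output_dir}.")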

eva/core/data/dataloaders/dataloader.py
CHANGED

@@ -38,7 +38,7 @@ class DataLoader:
         Mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`.
     """

-    num_workers: int = multiprocessing.cpu_count()
+    num_workers: int | None = None
     """How many workers to use for loading the data.

     By default, it will use the number of CPUs available.
@@ -71,7 +71,7 @@ class DataLoader:
             shuffle=self.shuffle,
             sampler=self.sampler,
             batch_sampler=self.batch_sampler,
-            num_workers=self.num_workers,
+            num_workers=self.num_workers or multiprocessing.cpu_count(),
             collate_fn=self.collate_fn,
             pin_memory=self.pin_memory,
             drop_last=self.drop_last,
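Note that the new fallback relies on truthiness, so an explicit `num_workers=0` also resolves to the CPU count; a small sketch of the resolution rule:

    import multiprocessing

    def resolve_num_workers(num_workers: int | None) -> int:
        # `or` treats 0 like None, so zero workers cannot be requested explicitly.
        return num_workers or multiprocessing.cpu_count()

    assert resolve_num_workers(None) == multiprocessing.cpu_count()
    assert resolve_num_workers(0) == multiprocessing.cpu_count()
    assert resolve_num_workers(4) == 4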

eva/core/data/splitting/random.py
CHANGED

@@ -24,12 +24,13 @@ def random_split(
     Returns:
         The indices of the train, validation, and test sets as lists.
     """
-    if train_ratio + val_ratio + test_ratio != 1.0:
-        raise ValueError("The sum of the ratios must be equal to 1.")
+    total_ratio = train_ratio + val_ratio + test_ratio
+    if total_ratio > 1.0:
+        raise ValueError("The sum of the ratios must be lower or equal to 1.")

-    np.random.seed(seed)
-    n_samples = len(samples)
-    indices = np.random.permutation(n_samples)
+    random_generator = np.random.default_rng(seed)
+    n_samples = int(total_ratio * len(samples))
+    indices = random_generator.permutation(len(samples))[:n_samples]

     n_train = int(np.floor(train_ratio * n_samples))
     n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
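Ratios may now sum to less than 1, in which case only the first `total_ratio * len(samples)` shuffled indices are kept; note that the per-split counts are then taken from this subsampled total, not the full dataset. A self-contained sketch of the arithmetic:

    import numpy as np

    samples, seed = list(range(100)), 42
    train_ratio, val_ratio, test_ratio = 0.4, 0.1, 0.1  # sums to 0.6 -> subsample

    total_ratio = train_ratio + val_ratio + test_ratio
    random_generator = np.random.default_rng(seed)
    n_samples = int(total_ratio * len(samples))               # 60 of 100 indices kept
    indices = random_generator.permutation(len(samples))[:n_samples]
    n_train = int(np.floor(train_ratio * n_samples))          # 24 training samples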

eva/core/data/splitting/stratified.py
CHANGED

@@ -28,10 +28,11 @@ def stratified_split(
     """
     if len(samples) != len(targets):
         raise ValueError("The number of samples and targets must be equal.")
-    if train_ratio + val_ratio + (test_ratio or 0) != 1.0:
-        raise ValueError("The sum of the ratios must be equal to 1.")
+    if train_ratio + val_ratio + (test_ratio or 0) > 1.0:
+        raise ValueError("The sum of the ratios must be lower or equal to 1.")

-    np.random.seed(seed)
+    use_all_samples = train_ratio + val_ratio + test_ratio == 1
+    random_generator = np.random.default_rng(seed)
     unique_classes, y_indices = np.unique(targets, return_inverse=True)
     n_classes = unique_classes.shape[0]
@@ -39,18 +40,23 @@ def stratified_split(

     for c in range(n_classes):
         class_indices = np.where(y_indices == c)[0]
-        np.random.shuffle(class_indices)
+        random_generator.shuffle(class_indices)

         n_train = int(np.floor(train_ratio * len(class_indices))) or 1
         n_val = (
             len(class_indices) - n_train
-            if test_ratio == 0.0
+            if test_ratio == 0.0 and use_all_samples
             else int(np.floor(val_ratio * len(class_indices))) or 1
         )

         train_indices.extend(class_indices[:n_train])
         val_indices.extend(class_indices[n_train : n_train + n_val])
         if test_ratio > 0.0:
-            test_indices.extend(class_indices[n_train + n_val :])
+            n_test = (
+                len(class_indices) - n_train - n_val
+                if use_all_samples
+                else int(np.floor(test_ratio * len(class_indices))) or 1
+            )
+            test_indices.extend(class_indices[n_train + n_val : n_train + n_val + n_test])

     return train_indices, val_indices, test_indices or None
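The `use_all_samples` flag distinguishes ratios that sum to 1 (assign every leftover sample) from genuine subsampling (size each split independently). A worked per-class example of the new test-split sizing:

    import numpy as np

    class_indices = np.arange(60)                   # one class with 60 samples
    train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15
    use_all_samples = train_ratio + val_ratio + test_ratio == 1  # True here

    n_train = int(np.floor(train_ratio * len(class_indices))) or 1  # 42
    n_val = int(np.floor(val_ratio * len(class_indices))) or 1      # 9
    n_test = (
        len(class_indices) - n_train - n_val        # 9: the remainder is used
        if use_all_samples
        else int(np.floor(test_ratio * len(class_indices))) or 1
    )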

eva/core/losses/cross_entropy.py
ADDED

@@ -0,0 +1,27 @@
+"""Cross-entropy based loss function."""
+
+from typing import Sequence
+
+import torch
+from torch import nn
+
+
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+    """A wrapper around torch.nn.CrossEntropyLoss that accepts weights in list format.
+
+    Needed for .yaml file loading & class instantiation with jsonargparse.
+    """
+
+    def __init__(
+        self, *args, weight: Sequence[float] | torch.Tensor | None = None, **kwargs
+    ) -> None:
+        """Initialize the loss function.
+
+        Args:
+            args: Positional arguments from the base class.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+        super().__init__(*args, **kwargs, weight=weight)
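A quick usage sketch of the wrapper (assuming the new `eva.core.losses` package re-exports it): since jsonargparse cannot build a `torch.Tensor` from a YAML list, class weights can now be passed as plain floats.

    import torch
    from eva.core.losses import CrossEntropyLoss  # re-export assumed

    criterion = CrossEntropyLoss(weight=[0.25, 0.75])  # list converted to a tensor
    logits, targets = torch.randn(8, 2), torch.randint(0, 2, (8,))
    loss = criterion(logits, targets)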
eva/core/metrics/__init__.py
CHANGED

@@ -3,8 +3,6 @@
 from eva.core.metrics.average_loss import AverageLoss
 from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
 from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
-from eva.core.metrics.generalized_dice import GeneralizedDiceScore
-from eva.core.metrics.mean_iou import MeanIoU
 from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema

 __all__ = [
@@ -12,8 +10,6 @@ __all__ = [
     "BinaryBalancedAccuracy",
     "BinaryClassificationMetrics",
     "MulticlassClassificationMetrics",
-    "GeneralizedDiceScore",
-    "MeanIoU",
     "Metric",
     "MetricCollection",
     "MetricModule",

eva/core/metrics/defaults/__init__.py
CHANGED

@@ -4,10 +4,8 @@ from eva.core.metrics.defaults.classification import (
     BinaryClassificationMetrics,
     MulticlassClassificationMetrics,
 )
-from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics

 __all__ = [
     "MulticlassClassificationMetrics",
     "BinaryClassificationMetrics",
-    "MulticlassSegmentationMetrics",
 ]

eva/core/models/modules/module.py
CHANGED

@@ -1,10 +1,10 @@
 """Base model module."""

+import os
 from typing import Any, Mapping

 import lightning.pytorch as pl
 import torch
-from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.utilities import memory
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import override
@@ -49,14 +49,14 @@ class ModelModule(pl.LightningModule):

     @property
     def metrics_device(self) -> torch.device:
-        """Returns the device by which the metrics should be calculated.
-        …
-        return …
+        """Returns the device by which the metrics should be calculated."""
+        device = os.getenv("METRICS_DEVICE", None)
+        if device is not None:
+            return torch.device(device)
+        elif self.device.type == "mps":
+            # mps seems to have compatibility issues with segmentation metrics
+            return torch.device("cpu")
+        return self.device

     @override
     def on_fit_start(self) -> None:
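The metrics device is now overridable through an environment variable, with an automatic CPU fallback on Apple's `mps` backend; for example:

    import os

    # Force metric computation onto the CPU even when the model runs elsewhere;
    # without this, the module falls back to self.device (or CPU on mps).
    os.environ["METRICS_DEVICE"] = "cpu"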

eva/core/models/transforms/extract_cls_features.py
CHANGED

@@ -7,13 +7,20 @@ from transformers import modeling_outputs
 class ExtractCLSFeatures:
     """Extracts the CLS token from a ViT model output."""

-    def __init__(self, cls_index: int = 0) -> None:
+    def __init__(
+        self, cls_index: int = 0, num_register_tokens: int = 0, include_patch_tokens: bool = False
+    ) -> None:
         """Initializes the transformation.

         Args:
             cls_index: The index of the CLS token in the output tensor.
+            num_register_tokens: The number of register tokens in the model output.
+            include_patch_tokens: Whether to concat the mean aggregated patch tokens with
+                the cls token.
         """
         self._cls_index = cls_index
+        self._num_register_tokens = num_register_tokens
+        self._include_patch_tokens = include_patch_tokens

     def __call__(
         self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
@@ -23,11 +30,12 @@ class ExtractCLSFeatures:
         Args:
             tensor: The tensor representing the model output.
         """
-        if isinstance(tensor, …):
-            …
+        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+            tensor = tensor.last_hidden_state
+
+        cls_token = tensor[:, self._cls_index, :]
+        if self._include_patch_tokens:
+            patch_tokens = tensor[:, 1 + self._num_register_tokens :, :]
+            return torch.cat([cls_token, patch_tokens.mean(1)], dim=-1)
+
+        return cls_token
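With `include_patch_tokens=True` the transform concatenates the CLS token with the mean of the patch tokens, doubling the feature dimension. A shape check against a dummy ViT output (import path and dimensions are illustrative):

    import torch
    from eva.core.models.transforms import ExtractCLSFeatures  # re-export assumed

    hidden = torch.randn(2, 1 + 4 + 196, 768)  # [CLS] + 4 registers + 14x14 patches
    transform = ExtractCLSFeatures(num_register_tokens=4, include_patch_tokens=True)
    features = transform(hidden)
    assert features.shape == (2, 2 * 768)  # CLS concatenated with mean patch token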

eva/core/models/transforms/extract_patch_features.py
CHANGED

@@ -10,13 +10,23 @@ from transformers import modeling_outputs
 class ExtractPatchFeatures:
     """Extracts the patch features from a ViT model output."""

-    def __init__(self, ignore_remaining_dims: bool = False) -> None:
+    def __init__(
+        self,
+        has_cls_token: bool = True,
+        num_register_tokens: int = 0,
+        ignore_remaining_dims: bool = False,
+    ) -> None:
         """Initializes the transformation.

         Args:
+            has_cls_token: If set to `True`, the model output is expected to have
+                a classification token.
+            num_register_tokens: The number of register tokens in the model output.
             ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
                 of the patch grid if it is not a square number.
         """
+        self._has_cls_token = has_cls_token
+        self._num_register_tokens = num_register_tokens
         self._ignore_remaining_dims = ignore_remaining_dims

     def __call__(
@@ -31,17 +41,19 @@ class ExtractPatchFeatures:
         A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
         representing the model output.
         """
+        num_skip = int(self._has_cls_token) + self._num_register_tokens
         if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1)
-            batch_size, hidden_size, patch_grid = features.shape
-            height = width = int(math.sqrt(patch_grid))
-            if height * width != patch_grid:
-                if self._ignore_remaining_dims:
-                    features = features[:, :, : height * width]
-                else:
-                    raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
-            patch_embeddings = features.view(batch_size, hidden_size, height, width)
+            features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)
         else:
-            …
+            features = tensor[:, num_skip:, :].permute(0, 2, 1)
+
+        batch_size, hidden_size, patch_grid = features.shape
+        height = width = int(math.sqrt(patch_grid))
+        if height * width != patch_grid:
+            if self._ignore_remaining_dims:
+                features = features[:, :, -height * width :]
+            else:
+                raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
+        patch_embeddings = features.view(batch_size, hidden_size, height, width)

         return [patch_embeddings]
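The extractor now skips `num_skip` leading tokens (CLS plus registers) from either a raw tensor or a Hugging Face output before reshaping into a square patch grid; a shape sketch under the same illustrative dimensions:

    import torch
    from eva.core.models.transforms import ExtractPatchFeatures  # re-export assumed

    hidden = torch.randn(2, 1 + 4 + 196, 768)  # [CLS] + 4 registers + 14x14 patches
    transform = ExtractPatchFeatures(has_cls_token=True, num_register_tokens=4)
    (patch_embeddings,) = transform(hidden)    # the transform returns a list
    assert patch_embeddings.shape == (2, 768, 14, 14)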

eva/core/utils/progress_bar.py
ADDED

@@ -0,0 +1,15 @@
+"""Progress bar utility functions."""
+
+import os
+
+from tqdm import tqdm as _tqdm
+
+
+def tqdm(*args, **kwargs) -> _tqdm:
+    """Wrapper function for `tqdm.tqdm`."""
+    refresh_rate = os.environ.get("TQDM_REFRESH_RATE")
+    refresh_rate = int(refresh_rate) if refresh_rate is not None else None
+    disable = bool(int(os.environ.get("TQDM_DISABLE", 0))) or (refresh_rate == 0)
+    kwargs.setdefault("disable", disable)
+    kwargs.setdefault("miniters", refresh_rate)
+    return _tqdm(*args, **kwargs)
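Both knobs are driven by environment variables: `TQDM_DISABLE=1` or `TQDM_REFRESH_RATE=0` silences the bars, while a positive refresh rate is forwarded as `miniters`. For example:

    import os

    os.environ["TQDM_REFRESH_RATE"] = "50"  # redraw every 50 iterations
    # os.environ["TQDM_DISABLE"] = "1"      # or disable progress bars entirely

    from eva.core.utils.progress_bar import tqdm

    for _ in tqdm(range(1000)):
        pass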

eva/vision/data/datasets/__init__.py
CHANGED

@@ -6,6 +6,7 @@ from eva.vision.data.datasets.classification import (
     MHIST,
     PANDA,
     Camelyon16,
+    PANDASmall,
     PatchCamelyon,
     WsiClassificationDataset,
 )
@@ -15,6 +16,7 @@ from eva.vision.data.datasets.segmentation import (
     EmbeddingsSegmentationDataset,
     ImageSegmentation,
     LiTS,
+    LiTSBalanced,
     MoNuSAC,
     TotalSegmentator2D,
 )
@@ -27,6 +29,7 @@ __all__ = [
     "CRC",
     "MHIST",
     "PANDA",
+    "PANDASmall",
     "Camelyon16",
     "PatchCamelyon",
     "WsiClassificationDataset",
@@ -34,6 +37,7 @@ __all__ = [
     "EmbeddingsSegmentationDataset",
     "ImageSegmentation",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",

eva/vision/data/datasets/classification/__init__.py
CHANGED

@@ -4,7 +4,7 @@ from eva.vision.data.datasets.classification.bach import BACH
 from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
 from eva.vision.data.datasets.classification.crc import CRC
 from eva.vision.data.datasets.classification.mhist import MHIST
-from eva.vision.data.datasets.classification.panda import PANDA
+from eva.vision.data.datasets.classification.panda import PANDA, PANDASmall
 from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
 from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset

@@ -15,5 +15,6 @@ __all__ = [
     "PatchCamelyon",
     "WsiClassificationDataset",
     "PANDA",
+    "PANDASmall",
     "Camelyon16",
 ]

eva/vision/data/datasets/classification/camelyon16.py
CHANGED

@@ -87,6 +87,7 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
         target_mpp: float = 0.5,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
         seed: int = 42,
     ) -> None:
         """Initializes the dataset.
@@ -100,6 +101,7 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
             target_mpp: Target microns per pixel (mpp) for the patches.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
         """
         self._split = split
@@ -119,6 +121,7 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @property
@@ -207,7 +210,7 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return …
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
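The new `coords_path` argument, threaded through `Camelyon16` above and `PANDA`/`WsiClassificationDataset` below, dumps the sampled patch coordinates to a CSV for inspection or reuse. A hedged instantiation sketch, assuming the usual `root`/`split` arguments:

    from eva.vision.data.datasets import Camelyon16

    dataset = Camelyon16(
        root="data/camelyon16",                       # illustrative dataset root
        split="train",
        coords_path="outputs/camelyon16_coords.csv",  # new in 0.1.3
    )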

eva/vision/data/datasets/classification/panda.py
CHANGED

@@ -49,6 +49,7 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
         target_mpp: float = 0.5,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
         seed: int = 42,
     ) -> None:
         """Initializes the dataset.
@@ -62,6 +63,7 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
             target_mpp: Target microns per pixel (mpp) for the patches.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
             seed: Random seed for reproducibility.
         """
         self._split = split
@@ -80,6 +82,7 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @property
@@ -132,7 +135,7 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return …
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
         """Loads the file paths of the corresponding dataset split."""
@@ -182,3 +185,16 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):

     def _get_id_from_path(self, file_path: str) -> str:
         return os.path.basename(file_path).replace(".tiff", "")
+
+
+class PANDASmall(PANDA):
+    """Small version of the PANDA dataset for quicker benchmarking."""
+
+    _train_split_ratio: float = 0.1
+    """Train split ratio."""
+
+    _val_split_ratio: float = 0.05
+    """Validation split ratio."""
+
+    _test_split_ratio: float = 0.05
+    """Test split ratio."""
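`PANDASmall` only overrides the split ratios, so it samples roughly 10%/5%/5% of the PANDA slides for train/val/test (about a fifth of the full dataset), which relies on the relaxed ratio checks in the splitting functions above:

    from eva.vision.data.datasets import PANDASmall

    # Same interface as PANDA, just smaller splits (root path is illustrative).
    dataset = PANDASmall(root="data/panda", split="train")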

eva/vision/data/datasets/classification/wsi.py
CHANGED

@@ -35,6 +35,7 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
         split: Literal["train", "val", "test"] | None = None,
         image_transforms: Callable | None = None,
         column_mapping: Dict[str, str] = default_column_mapping,
+        coords_path: str | None = None,
     ):
         """Initializes the dataset.

@@ -51,6 +52,7 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
             split: The split of the dataset to load.
             image_transforms: Transforms to apply to the extracted image patches.
             column_mapping: Mapping of the columns in the manifest file.
+            coords_path: File path to save the patch coordinates as .csv.
         """
         self._split = split
         self._column_mapping = self.default_column_mapping | column_mapping
@@ -66,6 +68,7 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
             target_mpp=target_mpp,
             backend=backend,
             image_transforms=image_transforms,
+            coords_path=coords_path,
         )

     @override
@@ -88,7 +91,7 @@ class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):

     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
-        return …
+        return wsi.MultiWsiDataset.load_metadata(self, index)

     def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
         df = pd.read_csv(manifest_path)

eva/vision/data/datasets/segmentation/__init__.py
CHANGED

@@ -5,6 +5,7 @@ from eva.vision.data.datasets.segmentation.bcss import BCSS
 from eva.vision.data.datasets.segmentation.consep import CoNSeP
 from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
 from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
 from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D

@@ -14,6 +15,7 @@ __all__ = [
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
 ]

eva/vision/data/datasets/segmentation/consep.py
CHANGED

@@ -37,8 +37,8 @@ class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
         root: str,
         sampler: samplers.Sampler | None = None,
         split: Literal["train", "val"] | None = None,
-        width: int = …,
-        height: int = …,
+        width: int = 250,
+        height: int = 250,
         target_mpp: float = 0.25,
         transforms: Callable | None = None,
     ) -> None: