kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.1.dist-info/METADATA +553 -0
- kaiko_eva-0.1.1.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Base dataset class for Embeddings."""
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
+
import multiprocessing
|
|
4
5
|
import os
|
|
5
|
-
from typing import Callable, Dict, Literal, Tuple
|
|
6
|
+
from typing import Callable, Dict, Generic, Literal, Tuple, TypeVar
|
|
6
7
|
|
|
7
|
-
import numpy as np
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import torch
|
|
10
10
|
from typing_extensions import override
|
|
@@ -12,16 +12,20 @@ from typing_extensions import override
|
|
|
12
12
|
from eva.core.data.datasets import base
|
|
13
13
|
from eva.core.utils import io
|
|
14
14
|
|
|
15
|
+
TargetType = TypeVar("TargetType")
|
|
16
|
+
"""The target data type."""
|
|
17
|
+
|
|
18
|
+
|
|
15
19
|
default_column_mapping: Dict[str, str] = {
|
|
16
20
|
"path": "embeddings",
|
|
17
21
|
"target": "target",
|
|
18
22
|
"split": "split",
|
|
19
|
-
"multi_id": "
|
|
23
|
+
"multi_id": "wsi_id",
|
|
20
24
|
}
|
|
21
25
|
"""The default column mapping of the variables to the manifest columns."""
|
|
22
26
|
|
|
23
27
|
|
|
24
|
-
class EmbeddingsDataset(base.Dataset):
|
|
28
|
+
class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
|
|
25
29
|
"""Abstract base class for embedding datasets."""
|
|
26
30
|
|
|
27
31
|
def __init__(
|
|
@@ -62,31 +66,7 @@ class EmbeddingsDataset(base.Dataset):
|
|
|
62
66
|
|
|
63
67
|
self._data: pd.DataFrame
|
|
64
68
|
|
|
65
|
-
|
|
66
|
-
def _load_embeddings(self, index: int) -> torch.Tensor:
|
|
67
|
-
"""Returns the `index`'th embedding sample.
|
|
68
|
-
|
|
69
|
-
Args:
|
|
70
|
-
index: The index of the data sample to load.
|
|
71
|
-
|
|
72
|
-
Returns:
|
|
73
|
-
The embedding sample as a tensor.
|
|
74
|
-
"""
|
|
75
|
-
|
|
76
|
-
@abc.abstractmethod
|
|
77
|
-
def _load_target(self, index: int) -> np.ndarray:
|
|
78
|
-
"""Returns the `index`'th target sample.
|
|
79
|
-
|
|
80
|
-
Args:
|
|
81
|
-
index: The index of the data sample to load.
|
|
82
|
-
|
|
83
|
-
Returns:
|
|
84
|
-
The sample target as an array.
|
|
85
|
-
"""
|
|
86
|
-
|
|
87
|
-
@abc.abstractmethod
|
|
88
|
-
def __len__(self) -> int:
|
|
89
|
-
"""Returns the total length of the data."""
|
|
69
|
+
self._set_multiprocessing_start_method()
|
|
90
70
|
|
|
91
71
|
def filename(self, index: int) -> str:
|
|
92
72
|
"""Returns the filename of the `index`'th data sample.
|
|
@@ -105,7 +85,11 @@ class EmbeddingsDataset(base.Dataset):
|
|
|
105
85
|
def setup(self):
|
|
106
86
|
self._data = self._load_manifest()
|
|
107
87
|
|
|
108
|
-
|
|
88
|
+
@abc.abstractmethod
|
|
89
|
+
def __len__(self) -> int:
|
|
90
|
+
"""Returns the total length of the data."""
|
|
91
|
+
|
|
92
|
+
def __getitem__(self, index) -> Tuple[torch.Tensor, TargetType]:
|
|
109
93
|
"""Returns the `index`'th data sample.
|
|
110
94
|
|
|
111
95
|
Args:
|
|
@@ -118,6 +102,28 @@ class EmbeddingsDataset(base.Dataset):
|
|
|
118
102
|
target = self._load_target(index)
|
|
119
103
|
return self._apply_transforms(embeddings, target)
|
|
120
104
|
|
|
105
|
+
@abc.abstractmethod
|
|
106
|
+
def _load_embeddings(self, index: int) -> torch.Tensor:
|
|
107
|
+
"""Returns the `index`'th embedding sample.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
index: The index of the data sample to load.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
The embedding sample as a tensor.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
@abc.abstractmethod
|
|
117
|
+
def _load_target(self, index: int) -> TargetType:
|
|
118
|
+
"""Returns the `index`'th target sample.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
index: The index of the data sample to load.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
The sample target as an array.
|
|
125
|
+
"""
|
|
126
|
+
|
|
121
127
|
def _load_manifest(self) -> pd.DataFrame:
|
|
122
128
|
"""Loads manifest file and filters the data based on the split column.
|
|
123
129
|
|
|
@@ -132,8 +138,8 @@ class EmbeddingsDataset(base.Dataset):
|
|
|
132
138
|
return data
|
|
133
139
|
|
|
134
140
|
def _apply_transforms(
|
|
135
|
-
self, embeddings: torch.Tensor, target:
|
|
136
|
-
) -> Tuple[torch.Tensor,
|
|
141
|
+
self, embeddings: torch.Tensor, target: TargetType
|
|
142
|
+
) -> Tuple[torch.Tensor, TargetType]:
|
|
137
143
|
"""Applies the transforms to the provided data and returns them.
|
|
138
144
|
|
|
139
145
|
Args:
|
|
@@ -150,3 +156,12 @@ class EmbeddingsDataset(base.Dataset):
|
|
|
150
156
|
target = self._target_transforms(target)
|
|
151
157
|
|
|
152
158
|
return embeddings, target
|
|
159
|
+
|
|
160
|
+
def _set_multiprocessing_start_method(self):
|
|
161
|
+
"""Sets the multiprocessing start method to spawn.
|
|
162
|
+
|
|
163
|
+
If the start method is not set explicitly, the torch data loaders will
|
|
164
|
+
use the OS default method, which for some unix systems is `fork` and
|
|
165
|
+
can lead to runtime issues such as deadlocks in this context.
|
|
166
|
+
"""
|
|
167
|
+
multiprocessing.set_start_method("spawn", force=True)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Functions for random splitting."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Sequence, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def random_split(
|
|
9
|
+
samples: Sequence[Any],
|
|
10
|
+
train_ratio: float,
|
|
11
|
+
val_ratio: float,
|
|
12
|
+
test_ratio: float = 0.0,
|
|
13
|
+
seed: int = 42,
|
|
14
|
+
) -> Tuple[List[int], List[int], List[int] | None]:
|
|
15
|
+
"""Splits the samples into random train, validation, and test (optional) sets.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
samples: The samples to split.
|
|
19
|
+
train_ratio: The ratio of the training set.
|
|
20
|
+
val_ratio: The ratio of the validation set.
|
|
21
|
+
test_ratio: The ratio of the test set (optional).
|
|
22
|
+
seed: The seed for reproducibility.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
The indices of the train, validation, and test sets as lists.
|
|
26
|
+
"""
|
|
27
|
+
if train_ratio + val_ratio + (test_ratio or 0) != 1:
|
|
28
|
+
raise ValueError("The sum of the ratios must be equal to 1.")
|
|
29
|
+
|
|
30
|
+
np.random.seed(seed)
|
|
31
|
+
n_samples = len(samples)
|
|
32
|
+
indices = np.random.permutation(n_samples)
|
|
33
|
+
|
|
34
|
+
n_train = int(np.floor(train_ratio * n_samples))
|
|
35
|
+
n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
|
|
36
|
+
|
|
37
|
+
train_indices = list(indices[:n_train])
|
|
38
|
+
val_indices = list(indices[n_train : n_train + n_val])
|
|
39
|
+
test_indices = list(indices[n_train + n_val :]) if test_ratio > 0.0 else None
|
|
40
|
+
|
|
41
|
+
return train_indices, val_indices, test_indices
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Functions for stratified splitting."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Sequence, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def stratified_split(
|
|
9
|
+
samples: Sequence[Any],
|
|
10
|
+
targets: Sequence[Any],
|
|
11
|
+
train_ratio: float,
|
|
12
|
+
val_ratio: float,
|
|
13
|
+
test_ratio: float = 0.0,
|
|
14
|
+
seed: int = 42,
|
|
15
|
+
) -> Tuple[List[int], List[int], List[int] | None]:
|
|
16
|
+
"""Splits the samples into stratified train, validation, and test (optional) sets.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
samples: The samples to split.
|
|
20
|
+
targets: The corresponding targets used for stratification.
|
|
21
|
+
train_ratio: The ratio of the training set.
|
|
22
|
+
val_ratio: The ratio of the validation set.
|
|
23
|
+
test_ratio: The ratio of the test set (optional).
|
|
24
|
+
seed: The seed for reproducibility.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
The indices of the train, validation, and test sets.
|
|
28
|
+
"""
|
|
29
|
+
if len(samples) != len(targets):
|
|
30
|
+
raise ValueError("The number of samples and targets must be equal.")
|
|
31
|
+
if train_ratio + val_ratio + (test_ratio or 0) != 1:
|
|
32
|
+
raise ValueError("The sum of the ratios must be equal to 1.")
|
|
33
|
+
|
|
34
|
+
np.random.seed(seed)
|
|
35
|
+
unique_classes, y_indices = np.unique(targets, return_inverse=True)
|
|
36
|
+
n_classes = unique_classes.shape[0]
|
|
37
|
+
|
|
38
|
+
train_indices, val_indices, test_indices = [], [], []
|
|
39
|
+
|
|
40
|
+
for c in range(n_classes):
|
|
41
|
+
class_indices = np.where(y_indices == c)[0]
|
|
42
|
+
np.random.shuffle(class_indices)
|
|
43
|
+
|
|
44
|
+
n_train = int(np.floor(train_ratio * len(class_indices))) or 1
|
|
45
|
+
n_val = (
|
|
46
|
+
len(class_indices) - n_train
|
|
47
|
+
if test_ratio == 0.0
|
|
48
|
+
else int(np.floor(val_ratio * len(class_indices))) or 1
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
train_indices.extend(class_indices[:n_train])
|
|
52
|
+
val_indices.extend(class_indices[n_train : n_train + n_val])
|
|
53
|
+
if test_ratio > 0.0:
|
|
54
|
+
test_indices.extend(class_indices[n_train + n_val :])
|
|
55
|
+
|
|
56
|
+
return train_indices, val_indices, test_indices or None
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Union
|
|
4
4
|
|
|
5
|
-
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
|
|
5
|
+
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger, WandbLogger
|
|
6
6
|
|
|
7
7
|
"""Supported loggers."""
|
|
8
|
-
ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger]
|
|
8
|
+
ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger, WandbLogger]
|
eva/core/loggers/log/__init__.py
CHANGED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Image log functionality."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from eva.core.loggers import loggers
|
|
8
|
+
from eva.core.loggers.log import utils
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@functools.singledispatch
def log_image(
    logger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Logs an image to the given logger.

    This is the fallback used when no overload is registered for the type of
    ``logger``; it delegates to ``utils.raise_not_supported`` (presumably raising
    an error for the unsupported logger).

    Args:
        logger: The desired logger.
        tag: The log tag.
        image: The image tensor to log. It should have the shape (3, H, W)
            with values normalized to [0, 1].
        step: The global step of the log.
    """
    unsupported_feature = "image"
    utils.raise_not_supported(logger, unsupported_feature)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@log_image.register
def _(
    loggers: list,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Forwards the image to each logger of the list, one by one."""
    emit = functools.partial(log_image, tag=tag, image=image, step=step)
    for single_logger in loggers:
        emit(single_logger)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@log_image.register
def _(
    logger: loggers.TensorBoardLogger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Writes the image to the TensorBoard experiment via ``add_image``."""
    experiment = logger.experiment
    experiment.add_image(tag=tag, img_tensor=image, global_step=step)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@log_image.register
def _(
    logger: loggers.WandbLogger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
    caption: str | None = None,
) -> None:
    """Adds a single image to a Wandb logger.

    Args:
        logger: The Wandb logger.
        tag: The key under which the image is logged.
        image: The image tensor to log; it is cast to float before logging.
        step: The global step of the log.
        caption: Optional caption attached to the image.
    """
    # `step` is kept as the 4th parameter, matching the generic `log_image`
    # signature, so positional calls cannot bind the step value to `caption`.
    logger.log_image(key=tag, images=[image.float()], step=step, caption=[caption])
|
|
@@ -51,6 +51,16 @@ def _(
|
|
|
51
51
|
)
|
|
52
52
|
|
|
53
53
|
|
|
54
|
+
@log_parameters.register
def _(
    logger: loggers_lib.WandbLogger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Merges the parameters into the Wandb run configuration.

    Note: ``tag`` is accepted to match the other overloads but is not used here.
    """
    run_config = logger.experiment.config
    run_config.update(parameters)
|
|
62
|
+
|
|
63
|
+
|
|
54
64
|
def _yaml_to_markdown(data: Dict[str, Any]) -> str:
|
|
55
65
|
"""Casts yaml data to markdown.
|
|
56
66
|
|
eva/core/metrics/__init__.py
CHANGED
|
@@ -3,15 +3,19 @@
|
|
|
3
3
|
from eva.core.metrics.average_loss import AverageLoss
|
|
4
4
|
from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
|
|
5
5
|
from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
|
|
6
|
+
from eva.core.metrics.generalized_dice import GeneralizedDiceScore
|
|
7
|
+
from eva.core.metrics.mean_iou import MeanIoU
|
|
6
8
|
from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
|
|
7
9
|
|
|
8
10
|
__all__ = [
|
|
9
11
|
"AverageLoss",
|
|
10
12
|
"BinaryBalancedAccuracy",
|
|
13
|
+
"BinaryClassificationMetrics",
|
|
14
|
+
"MulticlassClassificationMetrics",
|
|
15
|
+
"GeneralizedDiceScore",
|
|
16
|
+
"MeanIoU",
|
|
11
17
|
"Metric",
|
|
12
18
|
"MetricCollection",
|
|
13
19
|
"MetricModule",
|
|
14
20
|
"MetricsSchema",
|
|
15
|
-
"MulticlassClassificationMetrics",
|
|
16
|
-
"BinaryClassificationMetrics",
|
|
17
21
|
]
|
|
@@ -1,6 +1,13 @@
|
|
|
1
1
|
"""Default metric collections API."""
|
|
2
2
|
|
|
3
|
-
from eva.core.metrics.defaults.classification
|
|
4
|
-
|
|
3
|
+
from eva.core.metrics.defaults.classification import (
|
|
4
|
+
BinaryClassificationMetrics,
|
|
5
|
+
MulticlassClassificationMetrics,
|
|
6
|
+
)
|
|
7
|
+
from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics
|
|
5
8
|
|
|
6
|
-
__all__ = [
|
|
9
|
+
__all__ = [
|
|
10
|
+
"MulticlassClassificationMetrics",
|
|
11
|
+
"BinaryClassificationMetrics",
|
|
12
|
+
"MulticlassSegmentationMetrics",
|
|
13
|
+
]
|
|
@@ -3,4 +3,4 @@
|
|
|
3
3
|
from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
|
|
4
4
|
from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
|
|
5
5
|
|
|
6
|
-
__all__ = ["
|
|
6
|
+
__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"]
|
|
@@ -17,15 +17,6 @@ class BinaryClassificationMetrics(structs.MetricCollection):
|
|
|
17
17
|
) -> None:
|
|
18
18
|
"""Initializes the binary classification metrics.
|
|
19
19
|
|
|
20
|
-
The metrics instantiated here are:
|
|
21
|
-
|
|
22
|
-
- BinaryAUROC
|
|
23
|
-
- BinaryAccuracy
|
|
24
|
-
- BinaryBalancedAccuracy
|
|
25
|
-
- BinaryF1Score
|
|
26
|
-
- BinaryPrecision
|
|
27
|
-
- BinaryRecall
|
|
28
|
-
|
|
29
20
|
Args:
|
|
30
21
|
threshold: Threshold for transforming probability to binary (0,1) predictions
|
|
31
22
|
ignore_index: Specifies a target value that is ignored and does not
|
|
@@ -20,14 +20,6 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
|
|
|
20
20
|
) -> None:
|
|
21
21
|
"""Initializes the multi-class classification metrics.
|
|
22
22
|
|
|
23
|
-
The metrics instantiated here are:
|
|
24
|
-
|
|
25
|
-
- MulticlassAccuracy
|
|
26
|
-
- MulticlassPrecision
|
|
27
|
-
- MulticlassRecall
|
|
28
|
-
- MulticlassF1Score
|
|
29
|
-
- MulticlassAUROC
|
|
30
|
-
|
|
31
23
|
Args:
|
|
32
24
|
num_classes: Integer specifying the number of classes.
|
|
33
25
|
average: Defines the reduction that is applied over labels.
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Default metric collection for multiclass semantic segmentation tasks."""
|
|
2
|
+
|
|
3
|
+
from eva.core.metrics import generalized_dice, mean_iou, structs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MulticlassSegmentationMetrics(structs.MetricCollection):
    """Default metric collection for multi-class semantic segmentation tasks.

    Bundles a linearly weighted generalized Dice score and a mean IoU, both
    configured with the same number of classes, background handling and
    ignore index.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = False,
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Builds the segmentation metric collection.

        Args:
            num_classes: Integer specifying the number of classes.
            include_background: Whether to include the background class in the metrics computation.
            ignore_index: Integer specifying a target class to ignore. If given, this class
                index does not contribute to the returned score, regardless of reduction method.
            prefix: A string to add before the keys in the output dictionary.
            postfix: A string to add after the keys in the output dictionary.
        """
        dice_score = generalized_dice.GeneralizedDiceScore(
            num_classes=num_classes,
            include_background=include_background,
            weight_type="linear",
            ignore_index=ignore_index,
        )
        miou = mean_iou.MeanIoU(
            num_classes=num_classes,
            include_background=include_background,
            ignore_index=ignore_index,
        )
        super().__init__(metrics=[dice_score, miou], prefix=prefix, postfix=postfix)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Generalized Dice Score metric for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torchmetrics import segmentation
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
    """Defines the Generalized Dice Score.

    It expands the `torchmetrics` class by including an `ignore_index`
    functionality.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        weight_type: Literal["square", "simple", "linear"] = "linear",
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the metric.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class in the computation.
            weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                `"simple"`, or `"linear"` (the default here).
            ignore_index: Integer specifying a target class to ignore. When set, `update`
                masks `preds` and `target` before delegating to the parent class.
            per_class: Whether to compute the Dice score for each class separately. If set
                to ``False``, a single score over all classes is returned.
            kwargs: Additional keyword arguments forwarded unchanged to the `torchmetrics`
                parent class (e.g. its ``input_format`` option, if the installed version
                supports it), see :ref:`Metric kwargs` for more info.
        """
        super().__init__(
            num_classes=num_classes,
            include_background=include_background,
            weight_type=weight_type,
            per_class=per_class,
            **kwargs,
        )

        # Stored locally because the parent class is not given this argument.
        self.ignore_index = ignore_index

    @override
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Applies the `ignore_index` mask (if configured) and updates the parent state."""
        if self.ignore_index is not None:
            mask = target != self.ignore_index
            # NOTE(review): reducing with `.all(dim=-1)` discards every position along the
            # last axis whenever any one of them equals `ignore_index`, and masked positions
            # become 0 in both tensors (i.e. scored as class 0 when
            # `include_background=True`). For one-hot inputs, comparing 0/1 values against
            # a class index is also questionable. Confirm the intended input layout; an
            # element-wise mask is likely what was meant.
            mask = mask.all(dim=-1, keepdim=True)
            preds = preds * mask
            target = target * mask

        super().update(preds=preds, target=target)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Mean Intersection over Union (mIoU) metric for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torchmetrics
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MeanIoU(torchmetrics.Metric):
    """Computes Mean Intersection over Union (mIoU) for semantic segmentation.

    Fixes the torchmetrics implementation
    (issue https://github.com/Lightning-AI/torchmetrics/issues/2558)

    Inputs to `update` are class-index tensors of identical shape (N, ...);
    they are one-hot encoded internally.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    # Both states are plain per-class sums, so `Metric.forward` can use the cheaper
    # reduce-state path (a single `update` per batch instead of two).
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the metric.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class (index 0)
                in the computation.
            ignore_index: Integer specifying a target class to ignore. If given, pixels
                with this target value do not contribute to the returned score.
            per_class: Whether to return the IoU for each class separately. If set to
                ``False``, the mean IoU over all classes with a non-empty union is returned.
            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
        """
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.ignore_index = ignore_index
        self.per_class = per_class

        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulates the per-class intersection and union of a batch."""
        intersection, union = _compute_intersection_and_union(
            preds,
            target,
            num_classes=self.num_classes,
            include_background=self.include_background,
            ignore_index=self.ignore_index,
        )
        # Sum over the batch dimension; states hold one value per class.
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> torch.Tensor:
        """Computes the final (mean or per-class) IoU score.

        Classes whose accumulated union is zero get NaN in the per-class result
        and are excluded from the mean.
        """
        iou_valid = torch.gt(self.union, 0)
        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )
        if not self.per_class:
            iou = torch.mean(iou[iou_valid])
        return iou
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _compute_intersection_and_union(
|
|
71
|
+
preds: torch.Tensor,
|
|
72
|
+
target: torch.Tensor,
|
|
73
|
+
num_classes: int,
|
|
74
|
+
include_background: bool = False,
|
|
75
|
+
input_format: Literal["one-hot", "index"] = "index",
|
|
76
|
+
ignore_index: int | None = None,
|
|
77
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
78
|
+
"""Compute the intersection and union for semantic segmentation tasks.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
preds: Predicted tensor with shape (N, ...) where N is the batch size.
|
|
82
|
+
The shape can be (N, H, W) for 2D data or (N, D, H, W) for 3D data.
|
|
83
|
+
target: Ground truth tensor with the same shape as preds.
|
|
84
|
+
num_classes: Number of classes in the segmentation task.
|
|
85
|
+
include_background: Whether to include the background class in the computation.
|
|
86
|
+
input_format: Format of the input tensors.
|
|
87
|
+
ignore_index: Integer specifying a target class to ignore. If given, this class
|
|
88
|
+
index does not contribute to the returned score, regardless of reduction method.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Two tensors representing the intersection and union for each class.
|
|
92
|
+
Shape of each tensor is (N, num_classes).
|
|
93
|
+
|
|
94
|
+
Note:
|
|
95
|
+
- If input_format is "index", the tensors are converted to one-hot encoding.
|
|
96
|
+
- If include_background is `False`, the background class
|
|
97
|
+
(assumed to be the first channel) is ignored in the computation.
|
|
98
|
+
"""
|
|
99
|
+
if ignore_index is not None:
|
|
100
|
+
mask = target != ignore_index
|
|
101
|
+
mask = mask.all(dim=-1, keepdim=True)
|
|
102
|
+
preds = preds * mask
|
|
103
|
+
target = target * mask
|
|
104
|
+
|
|
105
|
+
if input_format == "index":
|
|
106
|
+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
|
|
107
|
+
target = torch.nn.functional.one_hot(target, num_classes=num_classes)
|
|
108
|
+
|
|
109
|
+
if not include_background:
|
|
110
|
+
preds[..., 0] = 0
|
|
111
|
+
target[..., 0] = 0
|
|
112
|
+
|
|
113
|
+
reduce_axis = list(range(1, preds.ndim - 1))
|
|
114
|
+
|
|
115
|
+
intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
|
|
116
|
+
target_sum = torch.sum(target, dim=reduce_axis)
|
|
117
|
+
pred_sum = torch.sum(preds, dim=reduce_axis)
|
|
118
|
+
union = target_sum + pred_sum - intersection
|
|
119
|
+
|
|
120
|
+
return intersection, union
|
|
@@ -44,4 +44,6 @@ class MetricsSchema:
|
|
|
44
44
|
if metrics is None or self.common is None:
|
|
45
45
|
return self.common or metrics
|
|
46
46
|
|
|
47
|
-
|
|
47
|
+
metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore
|
|
48
|
+
common = self.common if isinstance(self.common, list) else [self.common]
|
|
49
|
+
return common + metrics # type: ignore
|