kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
--- a/eva/core/data/datasets/embeddings/base.py
+++ b/eva/core/data/datasets/embeddings.py
@@ -1,10 +1,10 @@
 """Base dataset class for Embeddings."""
 
 import abc
+import multiprocessing
 import os
-from typing import Callable, Dict, Literal, Tuple
+from typing import Callable, Dict, Generic, Literal, Tuple, TypeVar
 
-import numpy as np
 import pandas as pd
 import torch
 from typing_extensions import override
@@ -12,16 +12,20 @@ from typing_extensions import override
 from eva.core.data.datasets import base
 from eva.core.utils import io
 
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
 default_column_mapping: Dict[str, str] = {
     "path": "embeddings",
     "target": "target",
     "split": "split",
-    "multi_id": "slide_id",
+    "multi_id": "wsi_id",
 }
 """The default column mapping of the variables to the manifest columns."""
 
 
-class EmbeddingsDataset(base.Dataset):
+class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
     """Abstract base class for embedding datasets."""
 
     def __init__(
@@ -62,31 +66,7 @@ class EmbeddingsDataset(base.Dataset):
 
         self._data: pd.DataFrame
 
-    @abc.abstractmethod
-    def _load_embeddings(self, index: int) -> torch.Tensor:
-        """Returns the `index`'th embedding sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The embedding sample as a tensor.
-        """
-
-    @abc.abstractmethod
-    def _load_target(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample target as an array.
-        """
-
-    @abc.abstractmethod
-    def __len__(self) -> int:
-        """Returns the total length of the data."""
+        self._set_multiprocessing_start_method()
 
     def filename(self, index: int) -> str:
         """Returns the filename of the `index`'th data sample.
@@ -105,7 +85,11 @@ class EmbeddingsDataset(base.Dataset):
     def setup(self):
         self._data = self._load_manifest()
 
-    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        """Returns the total length of the data."""
+
+    def __getitem__(self, index) -> Tuple[torch.Tensor, TargetType]:
         """Returns the `index`'th data sample.
 
         Args:
@@ -118,6 +102,28 @@ class EmbeddingsDataset(base.Dataset):
         target = self._load_target(index)
         return self._apply_transforms(embeddings, target)
 
+    @abc.abstractmethod
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Returns the `index`'th embedding sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The embedding sample as a tensor.
+        """
+
+    @abc.abstractmethod
+    def _load_target(self, index: int) -> TargetType:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+
     def _load_manifest(self) -> pd.DataFrame:
         """Loads manifest file and filters the data based on the split column.
 
@@ -132,8 +138,8 @@ class EmbeddingsDataset(base.Dataset):
         return data
 
     def _apply_transforms(
-        self, embeddings: torch.Tensor, target: np.ndarray
-    ) -> Tuple[torch.Tensor, np.ndarray]:
+        self, embeddings: torch.Tensor, target: TargetType
+    ) -> Tuple[torch.Tensor, TargetType]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
@@ -150,3 +156,12 @@ class EmbeddingsDataset(base.Dataset):
             target = self._target_transforms(target)
 
         return embeddings, target
+
+    def _set_multiprocessing_start_method(self):
+        """Sets the multiprocessing start method to spawn.
+
+        If the start method is not set explicitly, the torch data loaders will
+        use the OS default method, which for some unix systems is `fork` and
+        can lead to runtime issues such as deadlocks in this context.
+        """
+        multiprocessing.set_start_method("spawn", force=True)
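
The net effect of the changes above: `EmbeddingsDataset` is now generic over its target type, the abstract loaders moved below `__getitem__`, and the constructor forces the `spawn` multiprocessing start method. A minimal subclass sketch (the class name and manifest access below are illustrative assumptions, not part of the release):

import torch

from eva.core.data.datasets import embeddings


class TensorEmbeddingsDataset(embeddings.EmbeddingsDataset[torch.Tensor]):
    """Hypothetical subclass loading tensor embeddings and tensor targets."""

    def __len__(self) -> int:
        return len(self._data)  # `self._data` holds the manifest DataFrame

    def _load_embeddings(self, index: int) -> torch.Tensor:
        # Assumes the manifest's "embeddings" column stores tensor file paths,
        # matching the default_column_mapping above.
        return torch.load(self._data.iloc[index]["embeddings"])

    def _load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._data.iloc[index]["target"])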
--- /dev/null
+++ b/eva/core/data/splitting/__init__.py
@@ -0,0 +1,6 @@
+"""Dataset splitting API."""
+
+from eva.core.data.splitting.random import random_split
+from eva.core.data.splitting.stratified import stratified_split
+
+__all__ = ["random_split", "stratified_split"]
--- /dev/null
+++ b/eva/core/data/splitting/random.py
@@ -0,0 +1,41 @@
+"""Functions for random splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def random_split(
+    samples: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into random train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets as lists.
+    """
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    n_samples = len(samples)
+    indices = np.random.permutation(n_samples)
+
+    n_train = int(np.floor(train_ratio * n_samples))
+    n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
+
+    train_indices = list(indices[:n_train])
+    val_indices = list(indices[n_train : n_train + n_val])
+    test_indices = list(indices[n_train + n_val :]) if test_ratio > 0.0 else None
+
+    return train_indices, val_indices, test_indices
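
Usage sketch for `random_split`; note the ratio check uses exact float equality, so sums that are off by floating-point error (e.g. `0.6 + 0.2 + 0.2`) are rejected:

from eva.core.data.splitting import random_split

samples = list(range(100))
train_idx, val_idx, test_idx = random_split(samples, train_ratio=0.8, val_ratio=0.2)
# With test_ratio left at 0.0, validation takes the remainder and no test
# split is produced: 80 train indices, 20 val indices, test_idx is None.
assert len(train_idx) == 80 and len(val_idx) == 20 and test_idx is None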
--- /dev/null
+++ b/eva/core/data/splitting/stratified.py
@@ -0,0 +1,56 @@
+"""Functions for stratified splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def stratified_split(
+    samples: Sequence[Any],
+    targets: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into stratified train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        targets: The corresponding targets used for stratification.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets.
+    """
+    if len(samples) != len(targets):
+        raise ValueError("The number of samples and targets must be equal.")
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    unique_classes, y_indices = np.unique(targets, return_inverse=True)
+    n_classes = unique_classes.shape[0]
+
+    train_indices, val_indices, test_indices = [], [], []
+
+    for c in range(n_classes):
+        class_indices = np.where(y_indices == c)[0]
+        np.random.shuffle(class_indices)
+
+        n_train = int(np.floor(train_ratio * len(class_indices))) or 1
+        n_val = (
+            len(class_indices) - n_train
+            if test_ratio == 0.0
+            else int(np.floor(val_ratio * len(class_indices))) or 1
+        )
+
+        train_indices.extend(class_indices[:n_train])
+        val_indices.extend(class_indices[n_train : n_train + n_val])
+        if test_ratio > 0.0:
+            test_indices.extend(class_indices[n_train + n_val :])
+
+    return train_indices, val_indices, test_indices or None
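
`stratified_split` applies the same ratios per class, so class balance is preserved across splits. A sketch with an imbalanced two-class problem:

from eva.core.data.splitting import stratified_split

samples = list(range(100))
targets = [0] * 80 + [1] * 20
train_idx, val_idx, test_idx = stratified_split(samples, targets, train_ratio=0.8, val_ratio=0.2)
# Each class is split independently: 64 + 16 train indices and
# 16 + 4 val indices, keeping the 80/20 class ratio; test_idx is None.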
--- a/eva/core/loggers/experimental_loggers.py
+++ b/eva/core/loggers/experimental_loggers.py
@@ -2,7 +2,7 @@
 
 from typing import Union
 
-from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
+from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger, WandbLogger
 
 """Supported loggers."""
-ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger]
+ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger, WandbLogger]
--- a/eva/core/loggers/log/__init__.py
+++ b/eva/core/loggers/log/__init__.py
@@ -1,5 +1,6 @@
-"""Experiment loggers actions."""
+"""Experiment loggers operations."""
 
+from eva.core.loggers.log.image import log_image
 from eva.core.loggers.log.parameters import log_parameters
 
-__all__ = ["log_parameters"]
+__all__ = ["log_image", "log_parameters"]
--- /dev/null
+++ b/eva/core/loggers/log/image.py
@@ -0,0 +1,71 @@
+"""Image log functionality."""
+
+import functools
+
+import torch
+
+from eva.core.loggers import loggers
+from eva.core.loggers.log import utils
+
+
+@functools.singledispatch
+def log_image(
+    logger,
+    tag: str,
+    image: torch.Tensor,
+    step: int = 0,
+) -> None:
+    """Adds an image to the logger.
+
+    Args:
+        logger: The desired logger.
+        tag: The log tag.
+        image: The image tensor to log. It should have
+            the shape of (3,H,W) and (0,1) normalized.
+        step: The global step of the log.
+    """
+    utils.raise_not_supported(logger, "image")
+
+
+@log_image.register
+def _(
+    loggers: list,
+    tag: str,
+    image: torch.Tensor,
+    step: int = 0,
+) -> None:
+    """Adds an image to a list of supported loggers."""
+    for logger in loggers:
+        log_image(
+            logger,
+            tag=tag,
+            image=image,
+            step=step,
+        )
+
+
+@log_image.register
+def _(
+    logger: loggers.TensorBoardLogger,
+    tag: str,
+    image: torch.Tensor,
+    step: int = 0,
+) -> None:
+    """Adds an image to a TensorBoard logger."""
+    logger.experiment.add_image(
+        tag=tag,
+        img_tensor=image,
+        global_step=step,
+    )
+
+
+@log_image.register
+def _(
+    logger: loggers.WandbLogger,
+    tag: str,
+    image: torch.Tensor,
+    caption: str | None = None,
+    step: int = 0,
+) -> None:
+    """Adds a list of images to a Wandb logger."""
+    logger.log_image(key=tag, images=[image.float()], step=step, caption=[caption])
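
`log_image` dispatches on the type of its first argument, so the same call works for a single logger or a list of supported loggers. A minimal TensorBoard sketch (logger setup is illustrative):

import torch
from lightning.pytorch.loggers import TensorBoardLogger

from eva.core.loggers.log import log_image

logger = TensorBoardLogger("logs/")
image = torch.rand(3, 224, 224)  # (3, H, W), values normalized to (0, 1)
log_image(logger, tag="val/predictions", image=image, step=10)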
--- a/eva/core/loggers/log/parameters.py
+++ b/eva/core/loggers/log/parameters.py
@@ -51,6 +51,16 @@ def _(
     )
 
 
+@log_parameters.register
+def _(
+    logger: loggers_lib.WandbLogger,
+    tag: str,
+    parameters: Dict[str, Any],
+) -> None:
+    """Adds parameters to a Wandb logger."""
+    logger.experiment.config.update(parameters)
+
+
 def _yaml_to_markdown(data: Dict[str, Any]) -> str:
     """Casts yaml data to markdown.
 
--- /dev/null
+++ b/eva/core/loggers/loggers.py
@@ -0,0 +1,6 @@
+"""Experimental loggers."""
+
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+
+Loggers = TensorBoardLogger | WandbLogger
+"""Supported loggers."""
--- a/eva/core/metrics/__init__.py
+++ b/eva/core/metrics/__init__.py
@@ -3,15 +3,19 @@
 from eva.core.metrics.average_loss import AverageLoss
 from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
 from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
+from eva.core.metrics.generalized_dice import GeneralizedDiceScore
+from eva.core.metrics.mean_iou import MeanIoU
 from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
 
 __all__ = [
     "AverageLoss",
     "BinaryBalancedAccuracy",
+    "BinaryClassificationMetrics",
+    "MulticlassClassificationMetrics",
+    "GeneralizedDiceScore",
+    "MeanIoU",
     "Metric",
     "MetricCollection",
     "MetricModule",
     "MetricsSchema",
-    "MulticlassClassificationMetrics",
-    "BinaryClassificationMetrics",
 ]
--- a/eva/core/metrics/defaults/__init__.py
+++ b/eva/core/metrics/defaults/__init__.py
@@ -1,6 +1,13 @@
 """Default metric collections API."""
 
-from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
-from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
+from eva.core.metrics.defaults.classification import (
+    BinaryClassificationMetrics,
+    MulticlassClassificationMetrics,
+)
+from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics
 
-__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
+__all__ = [
+    "MulticlassClassificationMetrics",
+    "BinaryClassificationMetrics",
+    "MulticlassSegmentationMetrics",
+]
--- a/eva/core/metrics/defaults/classification/__init__.py
+++ b/eva/core/metrics/defaults/classification/__init__.py
@@ -3,4 +3,4 @@
 from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
 from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
 
-__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
+__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"]
--- a/eva/core/metrics/defaults/classification/binary.py
+++ b/eva/core/metrics/defaults/classification/binary.py
@@ -17,15 +17,6 @@ class BinaryClassificationMetrics(structs.MetricCollection):
     ) -> None:
         """Initializes the binary classification metrics.
 
-        The metrics instantiated here are:
-
-        - BinaryAUROC
-        - BinaryAccuracy
-        - BinaryBalancedAccuracy
-        - BinaryF1Score
-        - BinaryPrecision
-        - BinaryRecall
-
         Args:
             threshold: Threshold for transforming probability to binary (0,1) predictions
             ignore_index: Specifies a target value that is ignored and does not
--- a/eva/core/metrics/defaults/classification/multiclass.py
+++ b/eva/core/metrics/defaults/classification/multiclass.py
@@ -20,14 +20,6 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
     ) -> None:
         """Initializes the multi-class classification metrics.
 
-        The metrics instantiated here are:
-
-        - MulticlassAccuracy
-        - MulticlassPrecision
-        - MulticlassRecall
-        - MulticlassF1Score
-        - MulticlassAUROC
-
         Args:
            num_classes: Integer specifying the number of classes.
            average: Defines the reduction that is applied over labels.
--- /dev/null
+++ b/eva/core/metrics/defaults/segmentation/__init__.py
@@ -0,0 +1,5 @@
+"""Default segmentation metric collections API."""
+
+from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
+
+__all__ = ["MulticlassSegmentationMetrics"]
--- /dev/null
+++ b/eva/core/metrics/defaults/segmentation/multiclass.py
@@ -0,0 +1,43 @@
+"""Default metric collection for multiclass semantic segmentation tasks."""
+
+from eva.core.metrics import generalized_dice, mean_iou, structs
+
+
+class MulticlassSegmentationMetrics(structs.MetricCollection):
+    """Default metrics for multi-class semantic segmentation tasks."""
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = False,
+        ignore_index: int | None = None,
+        prefix: str | None = None,
+        postfix: str | None = None,
+    ) -> None:
+        """Initializes the multi-class semantic segmentation metrics.
+
+        Args:
+            num_classes: Integer specifying the number of classes.
+            include_background: Whether to include the background class in the metrics computation.
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            prefix: A string to add before the keys in the output dictionary.
+            postfix: A string to add after the keys in the output dictionary.
+        """
+        super().__init__(
+            metrics=[
+                generalized_dice.GeneralizedDiceScore(
+                    num_classes=num_classes,
+                    include_background=include_background,
+                    weight_type="linear",
+                    ignore_index=ignore_index,
+                ),
+                mean_iou.MeanIoU(
+                    num_classes=num_classes,
+                    include_background=include_background,
+                    ignore_index=ignore_index,
+                ),
+            ],
+            prefix=prefix,
+            postfix=postfix,
+        )
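
Instantiation sketch for the new collection (argument values are illustrative):

from eva.core.metrics.defaults import MulticlassSegmentationMetrics

# Bundles GeneralizedDiceScore (linear weighting) and MeanIoU under one prefix.
metrics = MulticlassSegmentationMetrics(
    num_classes=5,
    ignore_index=255,  # e.g. an "unlabeled" pixel value
    prefix="val/",
)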
--- /dev/null
+++ b/eva/core/metrics/generalized_dice.py
@@ -0,0 +1,59 @@
+"""Generalized Dice Score metric for semantic segmentation."""
+
+from typing import Any, Literal
+
+import torch
+from torchmetrics import segmentation
+from typing_extensions import override
+
+
+class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
+    """Defines the Generalized Dice Score.
+
+    It expands the `torchmetrics` class by including an `ignore_index`
+    functionality.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        weight_type: Literal["square", "simple", "linear"] = "linear",
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            weight_type: The type of weight to apply to each class. Can be one of `"square"`,
+                `"simple"`, or `"linear"`.
+            input_format: What kind of input the function receives. Choose between ``"one-hot"``
+                for one-hot encoded tensors or ``"index"`` for index tensors.
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(
+            num_classes=num_classes,
+            include_background=include_background,
+            weight_type=weight_type,
+            per_class=per_class,
+            **kwargs,
+        )
+
+        self.ignore_index = ignore_index
+
+    @override
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        if self.ignore_index is not None:
+            mask = target != self.ignore_index
+            mask = mask.all(dim=-1, keepdim=True)
+            preds = preds * mask
+            target = target * mask
+
+        super().update(preds=preds, target=target)
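
A usage sketch, assuming the `torchmetrics` base class' default one-hot input format with shape (N, num_classes, H, W):

import torch
from eva.core.metrics import GeneralizedDiceScore


def to_one_hot(indices: torch.Tensor) -> torch.Tensor:
    # (N, H, W) index tensor -> (N, C, H, W) one-hot tensor.
    return torch.nn.functional.one_hot(indices, num_classes=3).movedim(-1, 1)


metric = GeneralizedDiceScore(num_classes=3)
preds = to_one_hot(torch.randint(0, 3, (2, 16, 16)))
target = to_one_hot(torch.randint(0, 3, (2, 16, 16)))
print(metric(preds, target))  # scalar score in [0, 1]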
--- /dev/null
+++ b/eva/core/metrics/mean_iou.py
@@ -0,0 +1,120 @@
+"""Mean Intersection over Union (mIoU) metric for semantic segmentation."""
+
+from typing import Any, Literal, Tuple
+
+import torch
+import torchmetrics
+
+
+class MeanIoU(torchmetrics.Metric):
+    """Computes Mean Intersection over Union (mIoU) for semantic segmentation.
+
+    Fixes the torchmetrics implementation
+    (issue https://github.com/Lightning-AI/torchmetrics/issues/2558)
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(**kwargs)
+
+        self.num_classes = num_classes
+        self.include_background = include_background
+        self.ignore_index = ignore_index
+        self.per_class = per_class
+
+        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
+        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        """Update the state with the new data."""
+        intersection, union = _compute_intersection_and_union(
+            preds,
+            target,
+            num_classes=self.num_classes,
+            include_background=self.include_background,
+            ignore_index=self.ignore_index,
+        )
+        self.intersection += intersection.sum(0)
+        self.union += union.sum(0)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the final mean IoU score."""
+        iou_valid = torch.gt(self.union, 0)
+        iou = torch.where(
+            iou_valid,
+            torch.divide(self.intersection, self.union),
+            torch.nan,
+        )
+        if not self.per_class:
+            iou = torch.mean(iou[iou_valid])
+        return iou
+
+
+def _compute_intersection_and_union(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    num_classes: int,
+    include_background: bool = False,
+    input_format: Literal["one-hot", "index"] = "index",
+    ignore_index: int | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compute the intersection and union for semantic segmentation tasks.
+
+    Args:
+        preds: Predicted tensor with shape (N, ...) where N is the batch size.
+            The shape can be (N, H, W) for 2D data or (N, D, H, W) for 3D data.
+        target: Ground truth tensor with the same shape as preds.
+        num_classes: Number of classes in the segmentation task.
+        include_background: Whether to include the background class in the computation.
+        input_format: Format of the input tensors.
+        ignore_index: Integer specifying a target class to ignore. If given, this class
+            index does not contribute to the returned score, regardless of reduction method.
+
+    Returns:
+        Two tensors representing the intersection and union for each class.
+        Shape of each tensor is (N, num_classes).
+
+    Note:
+        - If input_format is "index", the tensors are converted to one-hot encoding.
+        - If include_background is `False`, the background class
+          (assumed to be the first channel) is ignored in the computation.
+    """
+    if ignore_index is not None:
+        mask = target != ignore_index
+        mask = mask.all(dim=-1, keepdim=True)
+        preds = preds * mask
+        target = target * mask
+
+    if input_format == "index":
+        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
+        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
+
+    if not include_background:
+        preds[..., 0] = 0
+        target[..., 0] = 0
+
+    reduce_axis = list(range(1, preds.ndim - 1))
+
+    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
+    target_sum = torch.sum(target, dim=reduce_axis)
+    pred_sum = torch.sum(preds, dim=reduce_axis)
+    union = target_sum + pred_sum - intersection
+
+    return intersection, union
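
Usage sketch for the fixed `MeanIoU`, which accepts index tensors directly (the default `input_format` of the helper above):

import torch
from eva.core.metrics import MeanIoU

metric = MeanIoU(num_classes=3, per_class=True)
preds = torch.randint(0, 3, (4, 32, 32))   # predicted class indices
target = torch.randint(0, 3, (4, 32, 32))  # ground-truth class indices
metric.update(preds, target)
print(metric.compute())  # shape (3,): one IoU value per class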
--- a/eva/core/metrics/structs/schemas.py
+++ b/eva/core/metrics/structs/schemas.py
@@ -44,4 +44,6 @@ class MetricsSchema:
         if metrics is None or self.common is None:
             return self.common or metrics
 
-        return [self.common, metrics]  # type: ignore
+        metrics = metrics if isinstance(metrics, list) else [metrics]  # type: ignore
+        common = self.common if isinstance(self.common, list) else [self.common]
+        return common + metrics  # type: ignore
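
The fix flattens the merged metric list instead of nesting it. The logic in isolation (a standalone sketch, not the actual class):

def merge(common, metrics):
    metrics = metrics if isinstance(metrics, list) else [metrics]
    common = common if isinstance(common, list) else [common]
    return common + metrics

merge("loss", ["iou", "dice"])
# -> ["loss", "iou", "dice"]; the previous code returned the nested
#    ["loss", ["iou", "dice"]] whenever either side was itself a list.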