kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (168)
  1. eva/core/callbacks/__init__.py +3 -2
  2. eva/core/callbacks/config.py +143 -0
  3. eva/core/callbacks/writers/__init__.py +6 -3
  4. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  5. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  6. eva/core/callbacks/writers/embeddings/base.py +192 -0
  7. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  8. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  9. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  10. eva/core/data/datasets/__init__.py +10 -2
  11. eva/core/data/datasets/classification/__init__.py +5 -2
  12. eva/core/data/datasets/classification/embeddings.py +15 -135
  13. eva/core/data/datasets/classification/multi_embeddings.py +110 -0
  14. eva/core/data/datasets/embeddings.py +167 -0
  15. eva/core/data/splitting/__init__.py +6 -0
  16. eva/core/data/splitting/random.py +41 -0
  17. eva/core/data/splitting/stratified.py +56 -0
  18. eva/core/data/transforms/__init__.py +3 -1
  19. eva/core/data/transforms/padding/__init__.py +5 -0
  20. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  21. eva/core/data/transforms/sampling/__init__.py +5 -0
  22. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  23. eva/core/loggers/__init__.py +7 -0
  24. eva/core/loggers/dummy.py +38 -0
  25. eva/core/loggers/experimental_loggers.py +8 -0
  26. eva/core/loggers/log/__init__.py +6 -0
  27. eva/core/loggers/log/image.py +71 -0
  28. eva/core/loggers/log/parameters.py +74 -0
  29. eva/core/loggers/log/utils.py +13 -0
  30. eva/core/loggers/loggers.py +6 -0
  31. eva/core/metrics/__init__.py +6 -2
  32. eva/core/metrics/defaults/__init__.py +10 -3
  33. eva/core/metrics/defaults/classification/__init__.py +1 -1
  34. eva/core/metrics/defaults/classification/binary.py +0 -9
  35. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  36. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  37. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  38. eva/core/metrics/generalized_dice.py +59 -0
  39. eva/core/metrics/mean_iou.py +120 -0
  40. eva/core/metrics/structs/schemas.py +3 -1
  41. eva/core/models/__init__.py +3 -1
  42. eva/core/models/modules/head.py +16 -15
  43. eva/core/models/modules/module.py +25 -1
  44. eva/core/models/modules/typings.py +14 -1
  45. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  46. eva/core/models/networks/__init__.py +1 -2
  47. eva/core/models/networks/mlp.py +2 -2
  48. eva/core/models/transforms/__init__.py +6 -0
  49. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  50. eva/core/models/transforms/extract_patch_features.py +47 -0
  51. eva/core/models/wrappers/__init__.py +13 -0
  52. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  53. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  54. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  55. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  56. eva/core/trainers/_recorder.py +69 -7
  57. eva/core/trainers/functional.py +23 -5
  58. eva/core/trainers/trainer.py +20 -6
  59. eva/core/utils/__init__.py +6 -0
  60. eva/core/utils/clone.py +27 -0
  61. eva/core/utils/memory.py +28 -0
  62. eva/core/utils/operations.py +26 -0
  63. eva/core/utils/parser.py +20 -0
  64. eva/vision/__init__.py +2 -2
  65. eva/vision/callbacks/__init__.py +5 -0
  66. eva/vision/callbacks/loggers/__init__.py +5 -0
  67. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  68. eva/vision/callbacks/loggers/batch/base.py +130 -0
  69. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  70. eva/vision/data/datasets/__init__.py +24 -4
  71. eva/vision/data/datasets/_utils.py +3 -3
  72. eva/vision/data/datasets/_validators.py +15 -2
  73. eva/vision/data/datasets/classification/__init__.py +6 -2
  74. eva/vision/data/datasets/classification/bach.py +10 -15
  75. eva/vision/data/datasets/classification/base.py +17 -24
  76. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  77. eva/vision/data/datasets/classification/crc.py +10 -15
  78. eva/vision/data/datasets/classification/mhist.py +10 -15
  79. eva/vision/data/datasets/classification/panda.py +184 -0
  80. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  81. eva/vision/data/datasets/classification/wsi.py +105 -0
  82. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  83. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  84. eva/vision/data/datasets/segmentation/base.py +31 -47
  85. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  86. eva/vision/data/datasets/segmentation/consep.py +156 -0
  87. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  88. eva/vision/data/datasets/segmentation/lits.py +178 -0
  89. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
  91. eva/vision/data/datasets/wsi.py +187 -0
  92. eva/vision/data/transforms/__init__.py +3 -2
  93. eva/vision/data/transforms/common/__init__.py +2 -1
  94. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  96. eva/vision/data/transforms/normalization/__init__.py +6 -0
  97. eva/vision/data/transforms/normalization/clamp.py +43 -0
  98. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  99. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  100. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  101. eva/vision/data/wsi/__init__.py +16 -0
  102. eva/vision/data/wsi/backends/__init__.py +69 -0
  103. eva/vision/data/wsi/backends/base.py +115 -0
  104. eva/vision/data/wsi/backends/openslide.py +73 -0
  105. eva/vision/data/wsi/backends/pil.py +52 -0
  106. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  107. eva/vision/data/wsi/patching/__init__.py +6 -0
  108. eva/vision/data/wsi/patching/coordinates.py +98 -0
  109. eva/vision/data/wsi/patching/mask.py +123 -0
  110. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  111. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  112. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  113. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  114. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  115. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  116. eva/vision/losses/__init__.py +5 -0
  117. eva/vision/losses/dice.py +40 -0
  118. eva/vision/models/__init__.py +4 -2
  119. eva/vision/models/modules/__init__.py +5 -0
  120. eva/vision/models/modules/semantic_segmentation.py +161 -0
  121. eva/vision/models/networks/__init__.py +1 -2
  122. eva/vision/models/networks/backbones/__init__.py +6 -0
  123. eva/vision/models/networks/backbones/_utils.py +39 -0
  124. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  125. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  126. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  127. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  128. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  129. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  130. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  131. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  132. eva/vision/models/networks/backbones/registry.py +47 -0
  133. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  134. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  135. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  136. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  137. eva/vision/models/networks/decoders/__init__.py +6 -0
  138. eva/vision/models/networks/decoders/decoder.py +7 -0
  139. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  140. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  141. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  142. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  143. eva/vision/models/wrappers/__init__.py +6 -0
  144. eva/vision/models/wrappers/from_registry.py +48 -0
  145. eva/vision/models/wrappers/from_timm.py +68 -0
  146. eva/vision/utils/colormap.py +77 -0
  147. eva/vision/utils/convert.py +67 -0
  148. eva/vision/utils/io/__init__.py +10 -4
  149. eva/vision/utils/io/image.py +21 -2
  150. eva/vision/utils/io/mat.py +36 -0
  151. eva/vision/utils/io/nifti.py +40 -15
  152. eva/vision/utils/io/text.py +10 -3
  153. kaiko_eva-0.1.0.dist-info/METADATA +553 -0
  154. kaiko_eva-0.1.0.dist-info/RECORD +205 -0
  155. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
  156. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
  157. eva/core/callbacks/writers/embeddings.py +0 -169
  158. eva/core/callbacks/writers/typings.py +0 -23
  159. eva/core/models/networks/transforms/__init__.py +0 -5
  160. eva/core/models/networks/wrappers/__init__.py +0 -8
  161. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  162. eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
  163. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  164. eva/vision/models/networks/postprocesses/cls.py +0 -25
  165. kaiko_eva-0.0.1.dist-info/METADATA +0 -405
  166. kaiko_eva-0.0.1.dist-info/RECORD +0 -110
  167. eva/core/models/{networks → wrappers}/_utils.py +0 -0
  168. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,38 @@
1
+ """Dummy logger class."""
2
+
3
+ import lightning.pytorch.loggers.logger
4
+
5
+
6
class DummyLogger(lightning.pytorch.loggers.logger.DummyLogger):
    """No-op lightning logger that still exposes a save directory.

    It acts as a placeholder when results are written to remote storage, because
    the common lightning loggers do not work with Azure Blob Storage:

    <https://github.com/Lightning-AI/pytorch-lightning/issues/18861>
    <https://github.com/Lightning-AI/pytorch-lightning/issues/19736>

    Loggers cannot simply be disabled in that setting: callbacks such as
    LearningRateMonitor or ModelCheckpoint require a logger to be present.
    """

    def __init__(self, save_dir: str) -> None:
        """Creates the placeholder logger.

        Args:
            save_dir: Directory reported through `save_dir`. Nothing is written
                there by this logger itself, but callbacks may store outputs there.
        """
        super().__init__()
        self._save_dir = save_dir

    def __deepcopy__(self, memo=None):
        """Returns the very same instance instead of creating a copy."""
        return self

    @property
    def save_dir(self) -> str:
        """The directory passed at construction time."""
        return self._save_dir
@@ -0,0 +1,8 @@
1
+ """Experiment loggers."""
2
+
3
+ from typing import Union
4
+
5
+ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger, WandbLogger
6
+
7
+ """Supported loggers."""
8
+ ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger, WandbLogger]
@@ -0,0 +1,6 @@
1
+ """Experiment loggers operations."""
2
+
3
+ from eva.core.loggers.log.image import log_image
4
+ from eva.core.loggers.log.parameters import log_parameters
5
+
6
+ __all__ = ["log_image", "log_parameters"]
@@ -0,0 +1,71 @@
1
+ """Image log functionality."""
2
+
3
+ import functools
4
+
5
+ import torch
6
+
7
+ from eva.core.loggers import loggers
8
+ from eva.core.loggers.log import utils
9
+
10
+
11
@functools.singledispatch
def log_image(
    logger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Logs an image through the given logger.

    This generic implementation is the fallback for logger types without a
    registered overload; it only reports that images are not supported.

    Args:
        logger: The logger to log with.
        tag: The tag under which the image is logged.
        image: The image tensor to log. It should have
            the shape of (3,H,W) and (0,1) normalized.
        step: The global step associated with the log entry.
    """
    unsupported_kind = "image"
    utils.raise_not_supported(logger, unsupported_kind)
28
+
29
+
30
@log_image.register
def _(
    loggers: list,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Forwards the same image to every logger contained in the list."""
    for single_logger in loggers:
        log_image(single_logger, tag=tag, image=image, step=step)
+ )
45
+
46
+
47
@log_image.register
def _(
    logger: loggers.TensorBoardLogger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Adds the image to the experiment object of a TensorBoard logger."""
    experiment = logger.experiment
    experiment.add_image(tag=tag, img_tensor=image, global_step=step)
60
+
61
+
62
@log_image.register
def _(
    logger: loggers.WandbLogger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
    caption: str | None = None,
) -> None:
    """Adds an image to a Wandb logger.

    `step` keeps the same position as in the generic `log_image` signature, so a
    positional step is interpreted identically for every logger type; `caption`
    is an optional, Wandb-specific extra that follows it.

    Args:
        logger: The Wandb logger.
        tag: The key under which the image is logged.
        image: The image tensor to log; it is cast to float before logging.
        step: The global step of the log.
        caption: Optional caption attached to the image.
    """
    # Wandb expects one entry per image for every per-image keyword argument.
    logger.log_image(key=tag, images=[image.float()], step=step, caption=[caption])
@@ -0,0 +1,74 @@
1
+ """Text log functionality."""
2
+
3
+ import functools
4
+ from typing import Any, Dict
5
+
6
+ import yaml
7
+
8
+ from eva.core.loggers import experimental_loggers as loggers_lib
9
+ from eva.core.loggers.log import utils
10
+
11
+
12
@functools.singledispatch
def log_parameters(
    logger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Logs a mapping of parameters through the given logger.

    This generic implementation is the fallback for logger types without a
    registered overload; it only reports that parameters are not supported.

    Args:
        logger: The logger to log with.
        tag: The tag under which the parameters are logged.
        parameters: The parameters to log.
    """
    utils.raise_not_supported(logger, data_type="parameters")
26
+
27
+
28
@log_parameters.register
def _(
    loggers: list,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Forwards the same parameters to every logger contained in the list."""
    for member in loggers:
        log_parameters(member, tag=tag, parameters=parameters)
37
+
38
+
39
@log_parameters.register
def _(
    logger: loggers_lib.TensorBoardLogger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Logs the parameters to TensorBoard as a yaml markdown text entry at step 0."""
    # Render first, then touch `logger.experiment`, keeping the original call order.
    markdown = _yaml_to_markdown(parameters)
    experiment = logger.experiment
    experiment.add_text(tag=tag, text_string=markdown, global_step=0)
52
+
53
+
54
@log_parameters.register
def _(
    logger: loggers_lib.WandbLogger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Merges the parameters into the config of the Wandb experiment.

    Note: `tag` is not used by this implementation.
    """
    run_config = logger.experiment.config
    run_config.update(parameters)
62
+
63
+
64
def _yaml_to_markdown(data: Dict[str, Any]) -> str:
    """Renders a mapping as a fenced ```yaml markdown code block.

    Args:
        data: The data to dump as yaml; key insertion order is preserved.

    Returns:
        The yaml dump wrapped in a markdown code fence.
    """
    fenced_body = yaml.dump(data, sort_keys=False)
    return "```yaml\n" + fenced_body + "```"
@@ -0,0 +1,13 @@
1
+ """Logging related utilities."""
2
+
3
+ from loguru import logger as cli_logger
4
+
5
+ from eva.core.loggers import ExperimentalLoggers
6
+
7
+
8
def raise_not_supported(logger: ExperimentalLoggers, data_type: str) -> None:
    """Reports that a logger cannot handle a given kind of data.

    Despite its name, nothing is raised: a blank separator is printed to stdout
    and a debug-level message is emitted through the CLI logger.

    Args:
        logger: The logger that lacks support.
        data_type: Name of the data kind that was requested (e.g. "image").
    """
    print("\n")
    logger_name = logger.__class__.__name__
    cli_logger.debug(f"Logger '{logger_name}' is not supported for '{data_type}' data.")
@@ -0,0 +1,6 @@
1
+ """Experimental loggers."""
2
+
3
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
4
+
5
# Union of the supported lightning experiment loggers. It is built with the
# PEP 604 `|` operator on classes, which requires Python 3.10+ at runtime.
Loggers = TensorBoardLogger | WandbLogger
"""Supported loggers."""
@@ -3,15 +3,19 @@
3
3
  from eva.core.metrics.average_loss import AverageLoss
4
4
  from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
5
5
  from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
6
+ from eva.core.metrics.generalized_dice import GeneralizedDiceScore
7
+ from eva.core.metrics.mean_iou import MeanIoU
6
8
  from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
7
9
 
8
10
  __all__ = [
9
11
  "AverageLoss",
10
12
  "BinaryBalancedAccuracy",
13
+ "BinaryClassificationMetrics",
14
+ "MulticlassClassificationMetrics",
15
+ "GeneralizedDiceScore",
16
+ "MeanIoU",
11
17
  "Metric",
12
18
  "MetricCollection",
13
19
  "MetricModule",
14
20
  "MetricsSchema",
15
- "MulticlassClassificationMetrics",
16
- "BinaryClassificationMetrics",
17
21
  ]
@@ -1,6 +1,13 @@
1
1
  """Default metric collections API."""
2
2
 
3
- from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
4
- from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
3
+ from eva.core.metrics.defaults.classification import (
4
+ BinaryClassificationMetrics,
5
+ MulticlassClassificationMetrics,
6
+ )
7
+ from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics
5
8
 
6
- __all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
9
+ __all__ = [
10
+ "MulticlassClassificationMetrics",
11
+ "BinaryClassificationMetrics",
12
+ "MulticlassSegmentationMetrics",
13
+ ]
@@ -3,4 +3,4 @@
3
3
  from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
4
4
  from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
5
5
 
6
- __all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
6
+ __all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"]
@@ -17,15 +17,6 @@ class BinaryClassificationMetrics(structs.MetricCollection):
17
17
  ) -> None:
18
18
  """Initializes the binary classification metrics.
19
19
 
20
- The metrics instantiated here are:
21
-
22
- - BinaryAUROC
23
- - BinaryAccuracy
24
- - BinaryBalancedAccuracy
25
- - BinaryF1Score
26
- - BinaryPrecision
27
- - BinaryRecall
28
-
29
20
  Args:
30
21
  threshold: Threshold for transforming probability to binary (0,1) predictions
31
22
  ignore_index: Specifies a target value that is ignored and does not
@@ -20,14 +20,6 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
20
20
  ) -> None:
21
21
  """Initializes the multi-class classification metrics.
22
22
 
23
- The metrics instantiated here are:
24
-
25
- - MulticlassAccuracy
26
- - MulticlassPrecision
27
- - MulticlassRecall
28
- - MulticlassF1Score
29
- - MulticlassAUROC
30
-
31
23
  Args:
32
24
  num_classes: Integer specifying the number of classes.
33
25
  average: Defines the reduction that is applied over labels.
@@ -0,0 +1,5 @@
1
+ """Default segmentation metric collections API."""
2
+
3
+ from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
4
+
5
+ __all__ = ["MulticlassSegmentationMetrics"]
@@ -0,0 +1,43 @@
1
+ """Default metric collection for multiclass semantic segmentation tasks."""
2
+
3
+ from eva.core.metrics import generalized_dice, mean_iou, structs
4
+
5
+
6
class MulticlassSegmentationMetrics(structs.MetricCollection):
    """Default metric collection for multi-class semantic segmentation.

    It bundles a linearly weighted Generalized Dice score and a mean IoU metric,
    both configured with the same class settings.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = False,
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Builds the segmentation metric collection.

        Args:
            num_classes: Integer specifying the number of classes.
            include_background: Whether to include the background class in the metrics computation.
            ignore_index: Integer specifying a target class to ignore. If given, this class
                index does not contribute to the returned score, regardless of reduction method.
            prefix: A string to add before the keys in the output dictionary.
            postfix: A string to add after the keys in the output dictionary.
        """
        shared_settings = {
            "num_classes": num_classes,
            "include_background": include_background,
            "ignore_index": ignore_index,
        }
        dice_score = generalized_dice.GeneralizedDiceScore(weight_type="linear", **shared_settings)
        iou_score = mean_iou.MeanIoU(**shared_settings)
        super().__init__(metrics=[dice_score, iou_score], prefix=prefix, postfix=postfix)
@@ -0,0 +1,59 @@
1
+ """Generalized Dice Score metric for semantic segmentation."""
2
+
3
+ from typing import Any, Literal
4
+
5
+ import torch
6
+ from torchmetrics import segmentation
7
+ from typing_extensions import override
8
+
9
+
10
class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
    """Defines the Generalized Dice Score.

    It expands the `torchmetrics` class by including an `ignore_index`
    functionality.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        weight_type: Literal["square", "simple", "linear"] = "linear",
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the metric.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class in the computation.
            weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                `"simple"`, or `"linear"`.
            ignore_index: Integer specifying a target class to ignore. If given, positions
                whose target equals this value are excluded from every class, and the class
                itself (when it is a valid index) does not contribute to the score.
            per_class: Whether to compute the Dice score for each class separately. If set
                to ``False``, a single score aggregated over all classes is returned.
            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
        """
        super().__init__(
            num_classes=num_classes,
            include_background=include_background,
            weight_type=weight_type,
            per_class=per_class,
            **kwargs,
        )

        self.ignore_index = ignore_index
        # Stored locally so the one-hot encoding below does not depend on how the
        # parent class stores (or adjusts) its own `num_classes`.
        self.orig_num_classes = num_classes

    @override
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Updates the metric state with a new batch.

        When `ignore_index` is set, `preds` and `target` are expected to be
        class-index tensors of identical shape `(N, ...)`; they are one-hot encoded
        to `(N, C, ...)` here so ignored positions can be removed from every class.
        """
        if self.ignore_index is not None:
            preds, target = self._mask_ignored(preds, target)

        super().update(preds=preds, target=target)

    def _mask_ignored(
        self, preds: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """One-hot encodes index tensors and zeroes out all ignored positions."""
        # Per-position mask: only the ignored pixels are dropped (not whole rows).
        valid = target != self.ignore_index
        # Relabel ignored positions to 0 so `one_hot` never receives an out-of-range
        # label (e.g. 255); they are zeroed out across all classes right below.
        preds = preds.masked_fill(~valid, 0).long()
        target = target.masked_fill(~valid, 0).long()

        keep = valid.unsqueeze(1)
        preds = torch.nn.functional.one_hot(preds, self.orig_num_classes).movedim(-1, 1) * keep
        target = torch.nn.functional.one_hot(target, self.orig_num_classes).movedim(-1, 1) * keep

        if 0 <= self.ignore_index < self.orig_num_classes:
            # An in-range ignored class must not contribute through predictions either.
            preds[:, self.ignore_index] = 0
            target[:, self.ignore_index] = 0

        return preds, target
@@ -0,0 +1,120 @@
1
+ """Mean Intersection over Union (mIoU) metric for semantic segmentation."""
2
+
3
+ from typing import Any, Literal, Tuple
4
+
5
+ import torch
6
+ import torchmetrics
7
+
8
+
9
class MeanIoU(torchmetrics.Metric):
    """Mean Intersection over Union (mIoU) for semantic segmentation.

    Fixes the torchmetrics implementation
    (issue https://github.com/Lightning-AI/torchmetrics/issues/2558).

    Per-class intersection and union counts are accumulated over all updates
    (summed across processes when synchronized) and converted to IoU scores
    in `compute`.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Sets up the metric configuration and its running states.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class in the computation
            ignore_index: Integer specifying a target class to ignore. If given, this class
                index does not contribute to the returned score, regardless of reduction method.
            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
                the metric will compute the mean IoU over all classes.
            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
        """
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.ignore_index = ignore_index
        self.per_class = per_class

        # One running counter per class for each quantity, reduced by summation.
        for state_name in ("intersection", "union"):
            self.add_state(state_name, default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulates the batch's per-class intersection and union counts."""
        batch_intersection, batch_union = _compute_intersection_and_union(
            preds,
            target,
            num_classes=self.num_classes,
            include_background=self.include_background,
            ignore_index=self.ignore_index,
        )
        self.intersection += batch_intersection.sum(dim=0)
        self.union += batch_union.sum(dim=0)

    def compute(self) -> torch.Tensor:
        """Returns the per-class IoU, or its mean over classes with a non-empty union.

        Classes whose accumulated union is zero get NaN in the per-class output
        and are left out of the mean.
        """
        has_union = self.union > 0
        class_iou = torch.where(has_union, self.intersection / self.union, torch.nan)
        if self.per_class:
            return class_iou
        return class_iou[has_union].mean()
68
+
69
+
70
+ def _compute_intersection_and_union(
71
+ preds: torch.Tensor,
72
+ target: torch.Tensor,
73
+ num_classes: int,
74
+ include_background: bool = False,
75
+ input_format: Literal["one-hot", "index"] = "index",
76
+ ignore_index: int | None = None,
77
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
78
+ """Compute the intersection and union for semantic segmentation tasks.
79
+
80
+ Args:
81
+ preds: Predicted tensor with shape (N, ...) where N is the batch size.
82
+ The shape can be (N, H, W) for 2D data or (N, D, H, W) for 3D data.
83
+ target: Ground truth tensor with the same shape as preds.
84
+ num_classes: Number of classes in the segmentation task.
85
+ include_background: Whether to include the background class in the computation.
86
+ input_format: Format of the input tensors.
87
+ ignore_index: Integer specifying a target class to ignore. If given, this class
88
+ index does not contribute to the returned score, regardless of reduction method.
89
+
90
+ Returns:
91
+ Two tensors representing the intersection and union for each class.
92
+ Shape of each tensor is (N, num_classes).
93
+
94
+ Note:
95
+ - If input_format is "index", the tensors are converted to one-hot encoding.
96
+ - If include_background is `False`, the background class
97
+ (assumed to be the first channel) is ignored in the computation.
98
+ """
99
+ if ignore_index is not None:
100
+ mask = target != ignore_index
101
+ mask = mask.all(dim=-1, keepdim=True)
102
+ preds = preds * mask
103
+ target = target * mask
104
+
105
+ if input_format == "index":
106
+ preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
107
+ target = torch.nn.functional.one_hot(target, num_classes=num_classes)
108
+
109
+ if not include_background:
110
+ preds[..., 0] = 0
111
+ target[..., 0] = 0
112
+
113
+ reduce_axis = list(range(1, preds.ndim - 1))
114
+
115
+ intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
116
+ target_sum = torch.sum(target, dim=reduce_axis)
117
+ pred_sum = torch.sum(preds, dim=reduce_axis)
118
+ union = target_sum + pred_sum - intersection
119
+
120
+ return intersection, union
@@ -44,4 +44,6 @@ class MetricsSchema:
44
44
  if metrics is None or self.common is None:
45
45
  return self.common or metrics
46
46
 
47
- return [self.common, metrics] # type: ignore
47
+ metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore
48
+ common = self.common if isinstance(self.common, list) else [self.common]
49
+ return common + metrics # type: ignore
@@ -1,12 +1,14 @@
1
1
  """Models API."""
2
2
 
3
3
  from eva.core.models.modules import HeadModule, InferenceModule
4
- from eva.core.models.networks import MLP, HuggingFaceModel, ModelFromFunction, ONNXModel
4
+ from eva.core.models.networks import MLP
5
+ from eva.core.models.wrappers import BaseModel, HuggingFaceModel, ModelFromFunction, ONNXModel
5
6
 
6
7
  __all__ = [
7
8
  "HeadModule",
8
9
  "InferenceModule",
9
10
  "MLP",
11
+ "BaseModel",
10
12
  "HuggingFaceModel",
11
13
  "ModelFromFunction",
12
14
  "ONNXModel",
@@ -1,11 +1,11 @@
1
1
  """"Neural Network Head Module."""
2
2
 
3
- from typing import Any, Callable
3
+ from typing import Any, Callable, Dict
4
4
 
5
5
  import torch
6
6
  from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
7
7
  from lightning.pytorch.utilities.types import STEP_OUTPUT
8
- from torch import optim
8
+ from torch import nn, optim
9
9
  from torch.optim import lr_scheduler
10
10
  from typing_extensions import override
11
11
 
@@ -13,6 +13,7 @@ from eva.core.metrics import structs as metrics_lib
13
13
  from eva.core.models.modules import module
14
14
  from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
15
15
  from eva.core.models.modules.utils import batch_postprocess, grad
16
+ from eva.core.utils import parser
16
17
 
17
18
 
18
19
  class HeadModule(module.ModelModule):
@@ -24,7 +25,7 @@ class HeadModule(module.ModelModule):
24
25
 
25
26
  def __init__(
26
27
  self,
27
- head: MODEL_TYPE,
28
+ head: Dict[str, Any] | MODEL_TYPE,
28
29
  criterion: Callable[..., torch.Tensor],
29
30
  backbone: MODEL_TYPE | None = None,
30
31
  optimizer: OptimizerCallable = optim.Adam,
@@ -36,6 +37,8 @@ class HeadModule(module.ModelModule):
36
37
 
37
38
  Args:
38
39
  head: The neural network that would be trained on the features.
40
+ If its a dictionary, it will be parsed to an object during the
41
+ `configure_model` step.
39
42
  criterion: The loss function to use.
40
43
  backbone: The feature extractor. If `None`, it will be expected
41
44
  that the input batch returns the features directly.
@@ -48,15 +51,23 @@ class HeadModule(module.ModelModule):
48
51
  """
49
52
  super().__init__(metrics=metrics, postprocess=postprocess)
50
53
 
51
- self.head = head
54
+ self.head = head # type: ignore
52
55
  self.criterion = criterion
53
56
  self.backbone = backbone
54
57
  self.optimizer = optimizer
55
58
  self.lr_scheduler = lr_scheduler
56
59
 
60
+ @override
61
+ def configure_model(self) -> Any:
62
+ if self.backbone is not None:
63
+ grad.deactivate_requires_grad(self.backbone)
64
+
65
+ if isinstance(self.head, dict):
66
+ self.head: MODEL_TYPE = parser.parse_object(self.head, expected_type=nn.Module)
67
+
57
68
  @override
58
69
  def configure_optimizers(self) -> Any:
59
- parameters = list(self.head.parameters())
70
+ parameters = self.head.parameters()
60
71
  optimizer = self.optimizer(parameters)
61
72
  lr_scheduler = self.lr_scheduler(optimizer)
62
73
  return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
@@ -66,11 +77,6 @@ class HeadModule(module.ModelModule):
66
77
  features = tensor if self.backbone is None else self.backbone(tensor)
67
78
  return self.head(features).squeeze(-1)
68
79
 
69
- @override
70
- def on_fit_start(self) -> None:
71
- if self.backbone is not None:
72
- grad.deactivate_requires_grad(self.backbone)
73
-
74
80
  @override
75
81
  def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
76
82
  return self._batch_step(batch)
@@ -88,11 +94,6 @@ class HeadModule(module.ModelModule):
88
94
  tensor = INPUT_BATCH(*batch).data
89
95
  return tensor if self.backbone is None else self.backbone(tensor)
90
96
 
91
- @override
92
- def on_fit_end(self) -> None:
93
- if self.backbone is not None:
94
- grad.activate_requires_grad(self.backbone)
95
-
96
97
  def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
97
98
  """Performs a model forward step and calculates the loss.
98
99