kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Dummy logger class."""
|
|
2
|
+
|
|
3
|
+
import lightning.pytorch.loggers.logger
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DummyLogger(lightning.pytorch.loggers.logger.DummyLogger):
    """No-op logger that only remembers a save directory.

    It serves as a stand-in when results are written to remote storage,
    because the common lightning loggers do not work with azure blob storage:

    <https://github.com/Lightning-AI/pytorch-lightning/issues/18861>
    <https://github.com/Lightning-AI/pytorch-lightning/issues/19736>

    Switching the loggers off entirely is not an option in that situation,
    since callbacks such as LearningRateMonitor or ModelCheckpoint require a
    logger to be present.
    """

    def __init__(self, save_dir: str) -> None:
        """Creates the logger.

        Args:
            save_dir: Directory path exposed through `save_dir`. The logger
                itself writes nothing there, but callbacks might use this
                path to save their outputs.
        """
        super().__init__()
        self._save_dir = save_dir

    @property
    def save_dir(self) -> str:
        """The directory path given at construction time."""
        return self._save_dir

    def __deepcopy__(self, memo=None):
        """Returns this very instance rather than a copy of it."""
        return self
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Image log functionality."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from eva.core.loggers import loggers
|
|
8
|
+
from eva.core.loggers.log import utils
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@functools.singledispatch
def log_image(
    logger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Logs an image with the given logger.

    This is the fallback of the single-dispatch function: it is reached for
    logger types without a registered implementation and only reports that
    image logging is not supported for them.

    Args:
        logger: The logger to dispatch on.
        tag: The tag under which the image is logged.
        image: The image tensor. It should have the shape
            of (3,H,W) and be (0,1) normalized.
        step: The global step of the log entry.
    """
    utils.raise_not_supported(logger, "image")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@log_image.register
def _(
    loggers: list,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Logs the image with every logger of the given list."""
    # Each element is re-dispatched on its own type.
    for entry in loggers:
        log_image(entry, tag=tag, image=image, step=step)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@log_image.register
def _(
    logger: loggers.TensorBoardLogger,
    tag: str,
    image: torch.Tensor,
    step: int = 0,
) -> None:
    """Logs the image through the experiment object of a TensorBoard logger."""
    experiment = logger.experiment
    experiment.add_image(tag=tag, img_tensor=image, global_step=step)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@log_image.register
def _(
    logger: loggers.WandbLogger,
    tag: str,
    image: torch.Tensor,
    caption: str | None = None,
    step: int = 0,
) -> None:
    """Logs a single image (with an optional caption) to a Wandb logger."""
    # The logger API takes lists, so the image and its caption are wrapped
    # into one-element lists. The image is cast to float before logging.
    images = [image.float()]
    captions = [caption]
    logger.log_image(key=tag, images=images, step=step, caption=captions)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Text log functionality."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Any, Dict
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from eva.core.loggers import experimental_loggers as loggers_lib
|
|
9
|
+
from eva.core.loggers.log import utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@functools.singledispatch
def log_parameters(
    logger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Logs a set of parameters with the given logger.

    This is the fallback of the single-dispatch function: it is reached for
    logger types without a registered implementation and only reports that
    parameter logging is not supported for them.

    Args:
        logger: The logger to dispatch on.
        tag: The tag under which the parameters are logged.
        parameters: The parameters to log.
    """
    utils.raise_not_supported(logger, "parameters")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@log_parameters.register
def _(
    loggers: list,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Logs the parameters with every logger of the given list."""
    # Each element is re-dispatched on its own type.
    for entry in loggers:
        log_parameters(
            entry,
            tag=tag,
            parameters=parameters,
        )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@log_parameters.register
def _(
    logger: loggers_lib.TensorBoardLogger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Logs the parameters as a YAML markdown text entry in TensorBoard."""
    experiment = logger.experiment
    # The text entry is always written at global step 0.
    experiment.add_text(
        tag=tag,
        text_string=_yaml_to_markdown(parameters),
        global_step=0,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@log_parameters.register
def _(
    logger: loggers_lib.WandbLogger,
    tag: str,
    parameters: Dict[str, Any],
) -> None:
    """Merges the parameters into the config of a Wandb run.

    The `tag` argument is accepted for signature parity but is not used here.
    """
    run_config = logger.experiment.config
    run_config.update(parameters)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _yaml_to_markdown(data: Dict[str, Any]) -> str:
    """Renders the data as a fenced YAML code block for markdown.

    Args:
        data: The data to serialize as YAML. The key order is kept as given.

    Returns:
        The YAML dump wrapped in a markdown code fence tagged as `yaml`.
    """
    dumped = yaml.dump(data, sort_keys=False)
    return "```yaml\n" + dumped + "```"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Logging related utilities."""
|
|
2
|
+
|
|
3
|
+
from loguru import logger as cli_logger
|
|
4
|
+
|
|
5
|
+
from eva.core.loggers import ExperimentalLoggers
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def raise_not_supported(logger: ExperimentalLoggers, data_type: str) -> None:
    """Reports that the logger does not support the given data type.

    Despite the name, nothing is raised: a blank line is printed to stdout
    and a debug-level message is emitted through the CLI logger.

    Args:
        logger: The logger lacking the support; only its class name is used.
        data_type: A label for the kind of data that could not be logged.
    """
    print("\n")
    logger_name = logger.__class__.__name__
    cli_logger.debug(f"Logger '{logger_name}' is not supported for '{data_type}' data.")
|
eva/core/metrics/__init__.py
CHANGED
|
@@ -3,15 +3,19 @@
|
|
|
3
3
|
from eva.core.metrics.average_loss import AverageLoss
|
|
4
4
|
from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
|
|
5
5
|
from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
|
|
6
|
+
from eva.core.metrics.generalized_dice import GeneralizedDiceScore
|
|
7
|
+
from eva.core.metrics.mean_iou import MeanIoU
|
|
6
8
|
from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
|
|
7
9
|
|
|
8
10
|
__all__ = [
|
|
9
11
|
"AverageLoss",
|
|
10
12
|
"BinaryBalancedAccuracy",
|
|
13
|
+
"BinaryClassificationMetrics",
|
|
14
|
+
"MulticlassClassificationMetrics",
|
|
15
|
+
"GeneralizedDiceScore",
|
|
16
|
+
"MeanIoU",
|
|
11
17
|
"Metric",
|
|
12
18
|
"MetricCollection",
|
|
13
19
|
"MetricModule",
|
|
14
20
|
"MetricsSchema",
|
|
15
|
-
"MulticlassClassificationMetrics",
|
|
16
|
-
"BinaryClassificationMetrics",
|
|
17
21
|
]
|
|
@@ -1,6 +1,13 @@
|
|
|
1
1
|
"""Default metric collections API."""
|
|
2
2
|
|
|
3
|
-
from eva.core.metrics.defaults.classification
|
|
4
|
-
|
|
3
|
+
from eva.core.metrics.defaults.classification import (
|
|
4
|
+
BinaryClassificationMetrics,
|
|
5
|
+
MulticlassClassificationMetrics,
|
|
6
|
+
)
|
|
7
|
+
from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics
|
|
5
8
|
|
|
6
|
-
__all__ = [
|
|
9
|
+
__all__ = [
|
|
10
|
+
"MulticlassClassificationMetrics",
|
|
11
|
+
"BinaryClassificationMetrics",
|
|
12
|
+
"MulticlassSegmentationMetrics",
|
|
13
|
+
]
|
|
@@ -3,4 +3,4 @@
|
|
|
3
3
|
from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
|
|
4
4
|
from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
|
|
5
5
|
|
|
6
|
-
__all__ = ["
|
|
6
|
+
__all__ = ["BinaryClassificationMetrics", "MulticlassClassificationMetrics"]
|
|
@@ -17,15 +17,6 @@ class BinaryClassificationMetrics(structs.MetricCollection):
|
|
|
17
17
|
) -> None:
|
|
18
18
|
"""Initializes the binary classification metrics.
|
|
19
19
|
|
|
20
|
-
The metrics instantiated here are:
|
|
21
|
-
|
|
22
|
-
- BinaryAUROC
|
|
23
|
-
- BinaryAccuracy
|
|
24
|
-
- BinaryBalancedAccuracy
|
|
25
|
-
- BinaryF1Score
|
|
26
|
-
- BinaryPrecision
|
|
27
|
-
- BinaryRecall
|
|
28
|
-
|
|
29
20
|
Args:
|
|
30
21
|
threshold: Threshold for transforming probability to binary (0,1) predictions
|
|
31
22
|
ignore_index: Specifies a target value that is ignored and does not
|
|
@@ -20,14 +20,6 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
|
|
|
20
20
|
) -> None:
|
|
21
21
|
"""Initializes the multi-class classification metrics.
|
|
22
22
|
|
|
23
|
-
The metrics instantiated here are:
|
|
24
|
-
|
|
25
|
-
- MulticlassAccuracy
|
|
26
|
-
- MulticlassPrecision
|
|
27
|
-
- MulticlassRecall
|
|
28
|
-
- MulticlassF1Score
|
|
29
|
-
- MulticlassAUROC
|
|
30
|
-
|
|
31
23
|
Args:
|
|
32
24
|
num_classes: Integer specifying the number of classes.
|
|
33
25
|
average: Defines the reduction that is applied over labels.
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Default metric collection for multiclass semantic segmentation tasks."""
|
|
2
|
+
|
|
3
|
+
from eva.core.metrics import generalized_dice, mean_iou, structs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MulticlassSegmentationMetrics(structs.MetricCollection):
    """Default metric collection for multi-class semantic segmentation."""

    def __init__(
        self,
        num_classes: int,
        include_background: bool = False,
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Builds the collection out of a generalized dice score and a mean IoU.

        Args:
            num_classes: The number of classes.
            include_background: Whether the background class takes part
                in the metrics computation.
            ignore_index: A target class index to ignore. If given, this class
                index does not contribute to the returned score, regardless
                of the reduction method.
            prefix: A string to add before the keys in the output dictionary.
            postfix: A string to add after the keys in the output dictionary.
        """
        # Options shared by both metrics of the collection.
        shared_options = {
            "num_classes": num_classes,
            "include_background": include_background,
            "ignore_index": ignore_index,
        }
        dice_score = generalized_dice.GeneralizedDiceScore(
            weight_type="linear", **shared_options
        )
        iou_score = mean_iou.MeanIoU(**shared_options)
        super().__init__(
            metrics=[dice_score, iou_score],
            prefix=prefix,
            postfix=postfix,
        )
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Generalized Dice Score metric for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torchmetrics import segmentation
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
    """Generalized Dice Score with support for an ignored class index.

    Extends the `torchmetrics` class of the same name with an
    `ignore_index` option that is applied before each state update.
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        weight_type: Literal["square", "simple", "linear"] = "linear",
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the metric.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class
                in the computation.
            weight_type: The type of weight to apply to each class. Can be
                one of `"square"`, `"simple"`, or `"linear"`.
            ignore_index: A target class index to ignore. If given, this class
                index does not contribute to the returned score, regardless
                of the reduction method.
            per_class: Forwarded to the parent metric; presumably selects a
                per-class score instead of a single aggregated one.
            kwargs: Additional keyword arguments forwarded to the parent
                metric, see :ref:`Metric kwargs` for more info.
        """
        super().__init__(
            num_classes=num_classes,
            include_background=include_background,
            weight_type=weight_type,
            per_class=per_class,
            **kwargs,
        )
        self.ignore_index = ignore_index

    @override
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        if self.ignore_index is None:
            super().update(preds=preds, target=target)
            return

        # NOTE(review): the mask is reduced over the LAST dimension, so a whole
        # slice along that axis is zeroed as soon as one of its entries equals
        # `ignore_index` -- confirm this matches the expected tensor layout.
        keep = (target != self.ignore_index).all(dim=-1, keepdim=True)
        super().update(preds=preds * keep, target=target * keep)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Mean Intersection over Union (mIoU) metric for semantic segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torchmetrics
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MeanIoU(torchmetrics.Metric):
    """Mean Intersection over Union (mIoU) metric for semantic segmentation.

    Fixes the torchmetrics implementation
    (issue https://github.com/Lightning-AI/torchmetrics/issues/2558)
    """

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        ignore_index: int | None = None,
        per_class: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the metric.

        Args:
            num_classes: The number of classes in the segmentation problem.
            include_background: Whether to include the background class
                in the computation.
            ignore_index: A target class index to ignore. If given, this class
                index does not contribute to the returned score, regardless
                of the reduction method.
            per_class: Whether to return the IoU of each class separately.
                If `False`, the mean IoU over the classes with a non-empty
                union is returned.
            kwargs: Additional keyword arguments, see :ref:`Metric kwargs`
                for more info.
        """
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.ignore_index = ignore_index
        self.per_class = per_class

        # Per-class accumulators, summed across processes when synchronized.
        for state_name in ("intersection", "union"):
            self.add_state(
                state_name,
                default=torch.zeros(num_classes),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulates the per-class intersection and union of a batch."""
        batch_intersection, batch_union = _compute_intersection_and_union(
            preds,
            target,
            num_classes=self.num_classes,
            include_background=self.include_background,
            ignore_index=self.ignore_index,
        )
        # Fold the leading (batch) dimension into the running totals.
        self.intersection += batch_intersection.sum(0)
        self.union += batch_union.sum(0)

    def compute(self) -> torch.Tensor:
        """Computes the IoU from the accumulated state.

        Returns:
            The per-class IoU (NaN for classes with an empty union) when
            `per_class` is set, otherwise the mean over the classes with a
            non-empty union.
        """
        has_union = self.union > 0
        per_class_iou = torch.where(
            has_union,
            self.intersection / self.union,
            torch.nan,
        )
        if self.per_class:
            return per_class_iou
        return per_class_iou[has_union].mean()
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _compute_intersection_and_union(
|
|
71
|
+
preds: torch.Tensor,
|
|
72
|
+
target: torch.Tensor,
|
|
73
|
+
num_classes: int,
|
|
74
|
+
include_background: bool = False,
|
|
75
|
+
input_format: Literal["one-hot", "index"] = "index",
|
|
76
|
+
ignore_index: int | None = None,
|
|
77
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
78
|
+
"""Compute the intersection and union for semantic segmentation tasks.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
preds: Predicted tensor with shape (N, ...) where N is the batch size.
|
|
82
|
+
The shape can be (N, H, W) for 2D data or (N, D, H, W) for 3D data.
|
|
83
|
+
target: Ground truth tensor with the same shape as preds.
|
|
84
|
+
num_classes: Number of classes in the segmentation task.
|
|
85
|
+
include_background: Whether to include the background class in the computation.
|
|
86
|
+
input_format: Format of the input tensors.
|
|
87
|
+
ignore_index: Integer specifying a target class to ignore. If given, this class
|
|
88
|
+
index does not contribute to the returned score, regardless of reduction method.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Two tensors representing the intersection and union for each class.
|
|
92
|
+
Shape of each tensor is (N, num_classes).
|
|
93
|
+
|
|
94
|
+
Note:
|
|
95
|
+
- If input_format is "index", the tensors are converted to one-hot encoding.
|
|
96
|
+
- If include_background is `False`, the background class
|
|
97
|
+
(assumed to be the first channel) is ignored in the computation.
|
|
98
|
+
"""
|
|
99
|
+
if ignore_index is not None:
|
|
100
|
+
mask = target != ignore_index
|
|
101
|
+
mask = mask.all(dim=-1, keepdim=True)
|
|
102
|
+
preds = preds * mask
|
|
103
|
+
target = target * mask
|
|
104
|
+
|
|
105
|
+
if input_format == "index":
|
|
106
|
+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
|
|
107
|
+
target = torch.nn.functional.one_hot(target, num_classes=num_classes)
|
|
108
|
+
|
|
109
|
+
if not include_background:
|
|
110
|
+
preds[..., 0] = 0
|
|
111
|
+
target[..., 0] = 0
|
|
112
|
+
|
|
113
|
+
reduce_axis = list(range(1, preds.ndim - 1))
|
|
114
|
+
|
|
115
|
+
intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
|
|
116
|
+
target_sum = torch.sum(target, dim=reduce_axis)
|
|
117
|
+
pred_sum = torch.sum(preds, dim=reduce_axis)
|
|
118
|
+
union = target_sum + pred_sum - intersection
|
|
119
|
+
|
|
120
|
+
return intersection, union
|
|
@@ -44,4 +44,6 @@ class MetricsSchema:
|
|
|
44
44
|
if metrics is None or self.common is None:
|
|
45
45
|
return self.common or metrics
|
|
46
46
|
|
|
47
|
-
|
|
47
|
+
metrics = metrics if isinstance(metrics, list) else [metrics] # type: ignore
|
|
48
|
+
common = self.common if isinstance(self.common, list) else [self.common]
|
|
49
|
+
return common + metrics # type: ignore
|
eva/core/models/__init__.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
"""Models API."""
|
|
2
2
|
|
|
3
3
|
from eva.core.models.modules import HeadModule, InferenceModule
|
|
4
|
-
from eva.core.models.networks import MLP
|
|
4
|
+
from eva.core.models.networks import MLP
|
|
5
|
+
from eva.core.models.wrappers import BaseModel, HuggingFaceModel, ModelFromFunction, ONNXModel
|
|
5
6
|
|
|
6
7
|
__all__ = [
|
|
7
8
|
"HeadModule",
|
|
8
9
|
"InferenceModule",
|
|
9
10
|
"MLP",
|
|
11
|
+
"BaseModel",
|
|
10
12
|
"HuggingFaceModel",
|
|
11
13
|
"ModelFromFunction",
|
|
12
14
|
"ONNXModel",
|
eva/core/models/modules/head.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Neural Network Head Module."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Callable
|
|
3
|
+
from typing import Any, Callable, Dict
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
|
|
7
7
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
8
|
-
from torch import optim
|
|
8
|
+
from torch import nn, optim
|
|
9
9
|
from torch.optim import lr_scheduler
|
|
10
10
|
from typing_extensions import override
|
|
11
11
|
|
|
@@ -13,6 +13,7 @@ from eva.core.metrics import structs as metrics_lib
|
|
|
13
13
|
from eva.core.models.modules import module
|
|
14
14
|
from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
|
|
15
15
|
from eva.core.models.modules.utils import batch_postprocess, grad
|
|
16
|
+
from eva.core.utils import parser
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
class HeadModule(module.ModelModule):
|
|
@@ -24,7 +25,7 @@ class HeadModule(module.ModelModule):
|
|
|
24
25
|
|
|
25
26
|
def __init__(
|
|
26
27
|
self,
|
|
27
|
-
head: MODEL_TYPE,
|
|
28
|
+
head: Dict[str, Any] | MODEL_TYPE,
|
|
28
29
|
criterion: Callable[..., torch.Tensor],
|
|
29
30
|
backbone: MODEL_TYPE | None = None,
|
|
30
31
|
optimizer: OptimizerCallable = optim.Adam,
|
|
@@ -36,6 +37,8 @@ class HeadModule(module.ModelModule):
|
|
|
36
37
|
|
|
37
38
|
Args:
|
|
38
39
|
head: The neural network that would be trained on the features.
|
|
40
|
+
If it's a dictionary, it will be parsed to an object during the
|
|
41
|
+
`configure_model` step.
|
|
39
42
|
criterion: The loss function to use.
|
|
40
43
|
backbone: The feature extractor. If `None`, it will be expected
|
|
41
44
|
that the input batch returns the features directly.
|
|
@@ -48,15 +51,23 @@ class HeadModule(module.ModelModule):
|
|
|
48
51
|
"""
|
|
49
52
|
super().__init__(metrics=metrics, postprocess=postprocess)
|
|
50
53
|
|
|
51
|
-
self.head = head
|
|
54
|
+
self.head = head # type: ignore
|
|
52
55
|
self.criterion = criterion
|
|
53
56
|
self.backbone = backbone
|
|
54
57
|
self.optimizer = optimizer
|
|
55
58
|
self.lr_scheduler = lr_scheduler
|
|
56
59
|
|
|
60
|
+
@override
|
|
61
|
+
def configure_model(self) -> Any:
|
|
62
|
+
if self.backbone is not None:
|
|
63
|
+
grad.deactivate_requires_grad(self.backbone)
|
|
64
|
+
|
|
65
|
+
if isinstance(self.head, dict):
|
|
66
|
+
self.head: MODEL_TYPE = parser.parse_object(self.head, expected_type=nn.Module)
|
|
67
|
+
|
|
57
68
|
@override
|
|
58
69
|
def configure_optimizers(self) -> Any:
|
|
59
|
-
parameters =
|
|
70
|
+
parameters = self.head.parameters()
|
|
60
71
|
optimizer = self.optimizer(parameters)
|
|
61
72
|
lr_scheduler = self.lr_scheduler(optimizer)
|
|
62
73
|
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
|
|
@@ -66,11 +77,6 @@ class HeadModule(module.ModelModule):
|
|
|
66
77
|
features = tensor if self.backbone is None else self.backbone(tensor)
|
|
67
78
|
return self.head(features).squeeze(-1)
|
|
68
79
|
|
|
69
|
-
@override
|
|
70
|
-
def on_fit_start(self) -> None:
|
|
71
|
-
if self.backbone is not None:
|
|
72
|
-
grad.deactivate_requires_grad(self.backbone)
|
|
73
|
-
|
|
74
80
|
@override
|
|
75
81
|
def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
|
76
82
|
return self._batch_step(batch)
|
|
@@ -88,11 +94,6 @@ class HeadModule(module.ModelModule):
|
|
|
88
94
|
tensor = INPUT_BATCH(*batch).data
|
|
89
95
|
return tensor if self.backbone is None else self.backbone(tensor)
|
|
90
96
|
|
|
91
|
-
@override
|
|
92
|
-
def on_fit_end(self) -> None:
|
|
93
|
-
if self.backbone is not None:
|
|
94
|
-
grad.activate_requires_grad(self.backbone)
|
|
95
|
-
|
|
96
97
|
def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
|
|
97
98
|
"""Performs a model forward step and calculates the loss.
|
|
98
99
|
|