kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/core/trainers/trainer.py
CHANGED

```diff
@@ -3,11 +3,14 @@
 import os
 from typing import Any
 
+import loguru
 from lightning.pytorch import loggers as pl_loggers
 from lightning.pytorch import trainer as pl_trainer
 from lightning.pytorch.utilities import argparse
+from lightning_fabric.utilities import cloud_io
 from typing_extensions import override
 
+from eva.core import loggers as eva_loggers
 from eva.core.data import datamodules
 from eva.core.models import modules
 from eva.core.trainers import _logging, functional
@@ -65,13 +68,23 @@ class Trainer(pl_trainer.Trainer):
             subdirectory: Whether to append a subdirectory to the output log.
         """
         self._log_dir = os.path.join(self.default_root_dir, self._session_id, subdirectory)
-        os.fspath(self._log_dir)
 
-        [five removed lines lost in extraction; only the fragment `logger.` survives]
+        enabled_loggers = []
+        if isinstance(self.loggers, list) and len(self.loggers) > 0:
+            for logger in self.loggers:
+                if isinstance(logger, (pl_loggers.CSVLogger, pl_loggers.TensorBoardLogger)):
+                    if not cloud_io._is_local_file_protocol(self.default_root_dir):
+                        loguru.logger.warning(
+                            f"Skipped {type(logger).__name__} as remote storage is not supported."
+                        )
+                        continue
+                    else:
+                        logger._root_dir = self.default_root_dir
+                        logger._name = self._session_id
+                        logger._version = subdirectory
+                enabled_loggers.append(logger)
+
+        self._loggers = enabled_loggers or [eva_loggers.DummyLogger(self._log_dir)]
 
     def run_evaluation_session(
         self,
@@ -94,4 +107,5 @@ class Trainer(pl_trainer.Trainer):
             base_model=model,
             datamodule=datamodule,
             n_runs=self._n_runs,
+            verbose=self._n_runs > 1,
         )
```
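For orientation: the new block re-points file-based loggers (CSV, TensorBoard) at the session's log directory, skips them when `default_root_dir` is remote, and falls back to a `DummyLogger` so the trainer always has at least one logger. A standalone sketch of that selection rule, with illustrative names rather than eva's API:

```python
from typing import Any, Callable, List


def select_loggers(
    loggers: List[Any], root_is_local: bool, make_dummy: Callable[[], Any]
) -> List[Any]:
    """Illustrative re-implementation of the logger selection above."""
    enabled = []
    for logger in loggers:
        if getattr(logger, "file_based", False) and not root_is_local:
            continue  # file-based loggers are dropped on remote storage
        enabled.append(logger)
    # An empty list is falsy, so the dummy logger kicks in when nothing survived.
    return enabled or [make_dummy()]


assert len(select_loggers([], root_is_local=True, make_dummy=object)) == 1
```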
eva/core/utils/__init__.py
CHANGED
eva/core/utils/clone.py
ADDED

```python
"""Clone related functions."""

import functools
from typing import Any, Dict, List

import torch


@functools.singledispatch
def clone(tensor_type: Any) -> Any:
    """Clone tensor objects."""
    raise TypeError(f"Unsupported input type: {type(input)}.")


@clone.register
def _(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.clone()


@clone.register
def _(tensors: list) -> List[torch.Tensor]:
    return list(map(clone, tensors))


@clone.register
def _(tensors: dict) -> Dict[str, torch.Tensor]:
    return {key: clone(tensors[key]) for key in tensors}
```
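A minimal usage sketch of the `clone` dispatcher above; the import path follows the new file location, and whether `eva.core.utils` re-exports it is not shown in this diff:

```python
import torch

from eva.core.utils.clone import clone

batch = {"embeddings": torch.rand(4, 8), "masks": [torch.zeros(2, 2), torch.ones(2, 2)]}
copied = clone(batch)  # recurses through dicts and lists, cloning each tensor
assert copied["embeddings"] is not batch["embeddings"]

try:
    clone(42)  # an unregistered type falls through to the base implementation
except TypeError as error:
    print(error)
```

One quirk worth noting: the fallback branch interpolates `type(input)` — the Python builtin, since no `input` variable is in scope — so the message always reports the builtin's type rather than the offending argument's. The `to_cpu` counterpart below formats `type(tensor_type)` instead.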
eva/core/utils/memory.py
ADDED

```python
"""Memory related functions."""

import functools
from typing import Any, Dict, List

import torch


@functools.singledispatch
def to_cpu(tensor_type: Any) -> Any:
    """Moves tensor objects to `cpu`."""
    raise TypeError(f"Unsupported input type: {type(tensor_type)}.")


@to_cpu.register
def _(tensor: torch.Tensor) -> torch.Tensor:
    detached_tensor = tensor.detach()
    return detached_tensor.cpu()


@to_cpu.register
def _(tensors: list) -> List[torch.Tensor]:
    return list(map(to_cpu, tensors))


@to_cpu.register
def _(tensors: dict) -> Dict[str, torch.Tensor]:
    return {key: to_cpu(tensors[key]) for key in tensors}
```
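A short usage sketch of `to_cpu`; the device selection is illustrative and only exercises the GPU path when CUDA is available:

```python
import torch

from eva.core.utils.memory import to_cpu

device = "cuda" if torch.cuda.is_available() else "cpu"
outputs = {"logits": torch.rand(2, 3, device=device, requires_grad=True)}
moved = to_cpu(outputs)
assert moved["logits"].device.type == "cpu"
assert not moved["logits"].requires_grad  # detached before the device move
```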
eva/core/utils/operations.py
ADDED

```python
"""Functional operations."""

import re
from typing import Iterable, List


def numeric_sort(item: Iterable[str], /) -> List[str]:
    """Sorts an iterable of strings treating embedded numbers as numeric values.

    Here the strings are compared based on their numeric value rather than their
    string representation.

    Args:
        item: An iterable of strings to be sorted.

    Returns:
        A list of strings sorted based on their numeric values.
    """
    return sorted(
        item,
        key=lambda value: re.sub(
            r"(\d+)",
            lambda num: f"{int(num[0]):010d}",
            value,
        ),
    )
```
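The sort key zero-pads every digit run to ten places, so lexicographic comparison of the padded keys matches numeric order. For example (file names are illustrative):

```python
from eva.core.utils.operations import numeric_sort

files = ["patch_10.pt", "patch_2.pt", "patch_1.pt"]
print(sorted(files))        # ['patch_1.pt', 'patch_10.pt', 'patch_2.pt']
print(numeric_sort(files))  # ['patch_1.pt', 'patch_2.pt', 'patch_10.pt']
```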
eva/core/utils/parser.py
ADDED

```python
"""Parsing related helper functions."""

from typing import Any, Dict

import jsonargparse


def parse_object(config: Dict[str, Any], expected_type: Any = Any) -> Any:
    """Parse object which is defined as dictionary."""
    parser = jsonargparse.ArgumentParser()
    parser.add_argument("module", type=expected_type)
    configuration = parser.parse_object({"module": config})
    init_object = parser.instantiate_classes(configuration)
    obj_module = init_object.module
    if isinstance(obj_module, jsonargparse.Namespace):
        raise ValueError(
            f"Failed to parsed object '{obj_module.class_path}'. "
            "Please check that the initialized arguments are valid."
        )
    return obj_module
```
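A sketch of how `parse_object` might be called, assuming jsonargparse's standard `class_path` / `init_args` convention for subclass instantiation; the config values are illustrative:

```python
from torch import nn

from eva.core.utils.parser import parse_object

config = {
    "class_path": "torch.nn.Linear",
    "init_args": {"in_features": 384, "out_features": 4},
}
head = parse_object(config, expected_type=nn.Module)  # returns an nn.Linear instance
```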
eva/vision/__init__.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 """eva vision API."""
 
 try:
-    from eva.vision import models, utils
+    from eva.vision import callbacks, losses, models, utils
     from eva.vision.data import datasets, transforms
 except ImportError as e:
     msg = (
@@ -11,4 +11,4 @@ except ImportError as e:
     )
     raise ImportError(str(e) + "\n\n" + msg) from e
 
-__all__ = ["models", "utils", "datasets", "transforms"]
+__all__ = ["callbacks", "losses", "models", "utils", "datasets", "transforms"]
```
eva/vision/callbacks/loggers/batch/base.py
ADDED

```python
"""Base batch callback logger."""

import abc

from lightning import pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

from eva.core.models.modules.typings import INPUT_TENSOR_BATCH


class BatchLogger(pl.Callback, abc.ABC):
    """Logs training and validation batch assets."""

    _batch_idx_to_log: int = 0
    """The batch index log."""

    def __init__(
        self,
        log_every_n_epochs: int | None = None,
        log_every_n_steps: int | None = None,
    ) -> None:
        """Initializes the callback object.

        Args:
            log_every_n_epochs: Epoch-wise logging frequency.
            log_every_n_steps: Step-wise logging frequency.
        """
        super().__init__()

        if log_every_n_epochs is None and log_every_n_steps is None:
            raise ValueError(
                "Please configure the logging frequency though "
                "`log_every_n_epochs` or `log_every_n_steps`."
            )
        if None not in [log_every_n_epochs, log_every_n_steps]:
            raise ValueError(
                "Arguments `log_every_n_epochs` and `log_every_n_steps` "
                "are mutually exclusive. Please configure one of them."
            )

        self._log_every_n_epochs = log_every_n_epochs
        self._log_every_n_steps = log_every_n_steps

    @override
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        batch_idx: int,
    ) -> None:
        if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None):
            return

        self._log_batch(
            trainer=trainer,
            batch=batch,
            outputs=outputs,
            tag="BatchTrain",
        )

    @override
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        if self._skip_logging(trainer, batch_idx):
            return

        self._log_batch(
            trainer=trainer,
            batch=batch,
            outputs=outputs,
            tag="BatchValidation",
        )

    @abc.abstractmethod
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        """Logs the batch data.

        Args:
            trainer: The trainer.
            outputs: The output of the train / val step.
            batch: The data batch.
            tag: The log tag.
        """

    def _skip_logging(
        self,
        trainer: pl.Trainer,
        batch_idx: int | None = None,
    ) -> bool:
        """Determines whether skip the logging step or not.

        Args:
            trainer: The trainer.
            batch_idx: The batch index.

        Returns:
            A boolean indicating whether to skip the step execution.
        """
        if trainer.global_step in [0, 1]:
            return False

        skip_due_frequency = any(
            [
                (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0,
                (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0,
            ]
        )

        conditions = [
            skip_due_frequency,
            not trainer.is_global_zero,
            batch_idx != self._batch_idx_to_log if batch_idx else False,
        ]
        return any(conditions)
```
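`BatchLogger` is abstract: subclasses supply `_log_batch`, while frequency gating and rank-zero checks come from the base class. A hypothetical minimal subclass, for illustration only; it assumes `batch[0]` is the image tensor, as in the segmentation logger below:

```python
from lightning import pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
from eva.vision.callbacks.loggers.batch import base


class ShapeLogger(base.BatchLogger):
    """Toy callback that prints the shape of the logged image batch."""

    @override
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        print(f"[{tag}] step={trainer.global_step} images={tuple(batch[0].shape)}")


callback = ShapeLogger(log_every_n_steps=100)  # pass via pl.Trainer(callbacks=[callback])
```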
eva/vision/callbacks/loggers/batch/segmentation.py
ADDED

```python
"""Segmentation datasets related data loggers."""

from typing import List, Tuple

import torch
import torchvision
from lightning import pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.nn import functional
from typing_extensions import override

from eva.core.loggers import log
from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
from eva.core.utils import to_cpu
from eva.vision.callbacks.loggers.batch import base
from eva.vision.utils import colormap, convert


class SemanticSegmentationLogger(base.BatchLogger):
    """Log the segmentation batch."""

    def __init__(
        self,
        max_samples: int = 10,
        number_of_images_per_subgrid_row: int = 1,
        log_images: bool = True,
        mean: Tuple[float, ...] = (0.0, 0.0, 0.0),
        std: Tuple[float, ...] = (1.0, 1.0, 1.0),
        log_every_n_epochs: int | None = None,
        log_every_n_steps: int | None = None,
    ) -> None:
        """Initializes the callback object.

        Args:
            max_samples: The maximum number of images displayed in the grid.
            number_of_images_per_subgrid_row: Number of images displayed in each
                row of each sub-grid (that is images, targets and predictions).
            log_images: Whether to log the input batch images.
            mean: The mean of the input images to de-normalize from.
            std: The std of the input images to de-normalize from.
            log_every_n_epochs: Epoch-wise logging frequency.
            log_every_n_steps: Step-wise logging frequency.
        """
        super().__init__(
            log_every_n_epochs=log_every_n_epochs,
            log_every_n_steps=log_every_n_steps,
        )

        self._max_samples = max_samples
        self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row
        self._log_images = log_images
        self._mean = mean
        self._std = std

    @override
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        predictions = outputs.get("predictions") if isinstance(outputs, dict) else None
        if predictions is None:
            raise ValueError("Key `predictions` is missing from the output.")

        data_batch, target_batch = batch[0], batch[1]
        data, targets, predictions = _subsample_tensors(
            tensors_stack=[data_batch, target_batch, predictions],
            max_samples=self._max_samples,
        )
        data, targets, predictions = to_cpu([data, targets, predictions])
        predictions = torch.argmax(predictions, dim=1)

        target_images = list(map(_draw_semantic_mask, targets))
        prediction_images = list(map(_draw_semantic_mask, predictions))
        image_groups = [target_images, prediction_images]

        if self._log_images:
            images = list(map(self._format_image, data))
            overlay_targets = [
                _overlay_mask(image, mask) for image, mask in zip(images, targets, strict=False)
            ]
            overlay_predictions = [
                _overlay_mask(image, mask) for image, mask in zip(images, predictions, strict=False)
            ]
            image_groups = [images, overlay_targets, overlay_predictions] + image_groups

        image_grid = _make_grid_from_image_groups(
            image_groups, self._number_of_images_per_subgrid_row
        )

        log.log_image(
            trainer.loggers,
            image=image_grid,
            tag=tag,
            step=trainer.global_step,
        )

    def _format_image(self, image: torch.Tensor) -> torch.Tensor:
        """Descaled an image tensor to (0, 255) uint8 tensor."""
        return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std)


def _subsample_tensors(
    tensors_stack: List[torch.Tensor],
    max_samples: int,
) -> List[torch.Tensor]:
    """Sub-samples tensors from a list of tensors in-place.

    Args:
        tensors_stack: A list of tensors.
        max_samples: The maximum number of images displayed in the grid.

    Returns:
        A sub-sample of the input tensors stack.
    """
    for i, tensor in enumerate(tensors_stack):
        tensors_stack[i] = tensor[:max_samples]
    return tensors_stack


def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
    """Draws a semantic mask to an image RGB tensor.

    The input semantic mask is a (H x W) shaped tensor with
    integer values which represent the pixel class id.

    Args:
        tensor: An image tensor of range [0., 1.].

    Returns:
        The image as a tensor of range [0., 255.].
    """
    tensor = torch.squeeze(tensor)
    height, width = tensor.shape[-2], tensor.shape[-1]
    red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
    for class_id, color in colormap.COLORMAP.items():
        indices = tensor == class_id
        red[indices], green[indices], blue[indices] = color
    return torch.stack([red, green, blue])


def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Overlays a segmentation mask onto an image.

    Args:
        image: A 3D tensor of shape (C, H, W) representing the image.
        mask: A 2D tensor of shape (H, W) representing the segmentation mask.
            Each pixel in the mask corresponds to a class label.

    Returns:
        A tensor of the same shape as the input image (C, H, W) with the
        segmentation mask overlaid on top. The output image retains the
        original color channels but with the mask applied, using the colors
        from the predefined colormap.
    """
    binary_masks = functional.one_hot(mask).permute(2, 0, 1).to(dtype=torch.bool)
    return torchvision.utils.draw_segmentation_masks(
        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
    )


def _make_grid_from_image_groups(
    image_groups: List[List[torch.Tensor]],
    number_of_images_per_subgrid_row: int = 2,
) -> torch.Tensor:
    """Creates a single image grid from image groups.

    For example, it can combine the input images, targets predictions into a single image.

    Args:
        image_groups: A list of lists of image tensors of shape (C x H x W)
            all of the same size.
        number_of_images_per_subgrid_row: Number of images displayed in each
            row of the sub-grid.

    Returns:
        An image grid as a `torch.Tensor`.
    """
    return torchvision.utils.make_grid(
        [
            torchvision.utils.make_grid(image_group, nrow=number_of_images_per_subgrid_row)
            for image_group in image_groups
        ],
        nrow=len(image_groups),
    )
```
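A configuration sketch for the callback above. The full module path is used because this diff does not show what the package `__init__` re-exports, and the ImageNet statistics are only an example de-normalization, assuming the inputs were normalized with them:

```python
from lightning import pytorch as pl

from eva.vision.callbacks.loggers.batch.segmentation import SemanticSegmentationLogger

callback = SemanticSegmentationLogger(
    max_samples=8,
    log_every_n_epochs=1,  # mutually exclusive with log_every_n_steps
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
trainer = pl.Trainer(callbacks=[callback])
```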
eva/vision/data/datasets/__init__.py
CHANGED

```diff
@@ -4,19 +4,39 @@ from eva.vision.data.datasets.classification import (
     BACH,
     CRC,
     MHIST,
+    PANDA,
+    Camelyon16,
     PatchCamelyon,
-    …
+    WsiClassificationDataset,
+)
+from eva.vision.data.datasets.segmentation import (
+    BCSS,
+    CoNSeP,
+    EmbeddingsSegmentationDataset,
+    ImageSegmentation,
+    LiTS,
+    MoNuSAC,
+    TotalSegmentator2D,
 )
-from eva.vision.data.datasets.segmentation import ImageSegmentation, TotalSegmentator2D
 from eva.vision.data.datasets.vision import VisionDataset
+from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
 
 __all__ = [
     "BACH",
+    "BCSS",
     "CRC",
     "MHIST",
-    "…
+    "PANDA",
+    "Camelyon16",
     "PatchCamelyon",
-    "…
+    "WsiClassificationDataset",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "ImageSegmentation",
+    "LiTS",
+    "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",
+    "MultiWsiDataset",
+    "WsiDataset",
 ]
```

(The three removed lines marked `…` are truncated in the extracted diff; only their leading fragments survive.)
eva/vision/data/datasets/_utils.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 """Dataset related function and helper functions."""
 
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 
 def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
     return ranges
 
 
-def ranges_to_indices(ranges: …
+def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
     """Unpacks a list of ranges to individual indices.
 
     Args:
-        ranges: …
+        ranges: A sequence of ranges to produce the indices from.
 
     Return:
         A list of the indices.
```

(The two removed lines marked `…` are truncated in the extracted diff.)
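The diff shows only the new signature and docstring, not the body, so the semantics are inferred: `ranges_to_indices` reads as the inverse of `indices_to_ranges`. A hedged sketch under that assumption:

```python
from eva.vision.data.datasets import _utils

indices = [0, 1, 2, 7, 8]
ranges = _utils.indices_to_ranges(indices)
# Assuming the two helpers are inverses, unpacking the ranges
# should reproduce the original indices.
assert _utils.ranges_to_indices(ranges) == indices
```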
eva/vision/data/datasets/_validators.py
CHANGED

```diff
@@ -13,7 +13,7 @@ _SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and
 def check_dataset_integrity(
     dataset: vision.VisionDataset,
     *,
-    length: int,
+    length: int | None,
     n_classes: int,
     first_and_last_labels: Tuple[str, str],
 ) -> None:
@@ -23,7 +23,7 @@ def check_dataset_integrity(
         ValueError: If the input dataset's values do not
             match the expected ones.
     """
-    if len(dataset) != length:
+    if length and len(dataset) != length:
         raise ValueError(
             f"Dataset's '{dataset.__class__.__qualname__}' length "
             f"({len(dataset)}) does not match the expected one ({length}). "
@@ -57,3 +57,16 @@ def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
     if download_available:
         error_message += " You can set `download=True` to download the dataset automatically."
     raise FileNotFoundError(error_message)
+
+
+def check_number_of_files(file_paths: List[str], expected_length: int, split: str | None) -> None:
+    """Verifies the number of files in the dataset.
+
+    Raise:
+        ValueError: If the number of files in the dataset does not match the expected one.
+    """
+    if len(file_paths) != expected_length:
+        raise ValueError(
+            f"Expected {expected_length} files, for split '{split}' found {len(file_paths)}. "
+            f"{_SUFFIX_ERROR_MESSAGE}"
+        )
```
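A usage sketch of the new validator; the file names are illustrative:

```python
from eva.vision.data.datasets import _validators

# Passes silently when the count matches; otherwise raises ValueError
# with the shared download-verification suffix.
_validators.check_number_of_files(
    file_paths=["0001.nii", "0002.nii"],
    expected_length=2,
    split="train",
)
```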
eva/vision/data/datasets/classification/__init__.py
CHANGED

```diff
@@ -1,15 +1,19 @@
 """Image classification datasets API."""
 
 from eva.vision.data.datasets.classification.bach import BACH
+from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
 from eva.vision.data.datasets.classification.crc import CRC
 from eva.vision.data.datasets.classification.mhist import MHIST
+from eva.vision.data.datasets.classification.panda import PANDA
 from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
-from eva.vision.data.datasets.classification.…
+from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
 
 __all__ = [
     "BACH",
     "CRC",
     "MHIST",
     "PatchCamelyon",
-    "…
+    "WsiClassificationDataset",
+    "PANDA",
+    "Camelyon16",
 ]
```

(The two removed lines marked `…` are truncated in the extracted diff.)