kaiko-eva 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Base batch callback logger."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
|
|
5
|
+
from lightning import pytorch as pl
|
|
6
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BatchLogger(pl.Callback, abc.ABC):
    """Base callback that logs assets of selected training and validation batches.

    Subclasses implement :meth:`_log_batch` (what to log); this class decides
    when it is invoked, based on either an epoch-wise or a step-wise frequency.
    """

    # Index of the batch that is logged whenever `_skip_logging` receives a
    # batch index to check against.
    _batch_idx_to_log: int = 0

    def __init__(
        self,
        log_every_n_epochs: int | None = None,
        log_every_n_steps: int | None = None,
    ) -> None:
        """Initializes the callback object.

        Exactly one of the two frequency arguments must be configured.

        Args:
            log_every_n_epochs: Epoch-wise logging frequency.
            log_every_n_steps: Step-wise logging frequency.

        Raises:
            ValueError: If neither or both of the frequency arguments are set.
        """
        super().__init__()

        if log_every_n_epochs is None and log_every_n_steps is None:
            raise ValueError(
                "Please configure the logging frequency through "
                "`log_every_n_epochs` or `log_every_n_steps`."
            )
        if log_every_n_epochs is not None and log_every_n_steps is not None:
            raise ValueError(
                "Arguments `log_every_n_epochs` and `log_every_n_steps` "
                "are mutually exclusive. Please configure one of them."
            )

        self._log_every_n_epochs = log_every_n_epochs
        self._log_every_n_steps = log_every_n_steps

    @override
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        batch_idx: int,
    ) -> None:
        # In epoch-wise mode only the designated batch of an epoch is logged, so
        # the batch index is checked; in step-wise mode no batch index is passed
        # and any training batch that matches the step frequency is logged.
        if self._skip_logging(trainer, batch_idx if self._log_every_n_epochs else None):
            return

        self._log_batch(
            trainer=trainer,
            batch=batch,
            outputs=outputs,
            tag="BatchTrain",
        )

    @override
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Validation always restricts logging to the designated batch index.
        if self._skip_logging(trainer, batch_idx):
            return

        self._log_batch(
            trainer=trainer,
            batch=batch,
            outputs=outputs,
            tag="BatchValidation",
        )

    @abc.abstractmethod
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        """Logs the batch data.

        Args:
            trainer: The trainer.
            outputs: The output of the train / val step.
            batch: The data batch.
            tag: The log tag.
        """

    def _skip_logging(
        self,
        trainer: pl.Trainer,
        batch_idx: int | None = None,
    ) -> bool:
        """Determines whether to skip the logging step or not.

        Logging only ever happens on the global rank-zero process and, when a
        batch index is given, only for the designated batch. Within those
        constraints, the first global steps (0 and 1) are always logged;
        afterwards the configured epoch- or step-wise frequency applies.

        Args:
            trainer: The trainer.
            batch_idx: The index of the current batch, or `None` to not
                restrict logging to a specific batch.

        Returns:
            A boolean indicating whether to skip the step execution.
        """
        # Only rank zero logs; otherwise every process would emit the same asset.
        if not trainer.is_global_zero:
            return True

        # `is not None` (rather than truthiness) so that index 0 is compared too.
        if batch_idx is not None and batch_idx != self._batch_idx_to_log:
            return True

        # Always log at the very beginning, regardless of the configured frequency.
        if trainer.global_step in (0, 1):
            return False

        return any(
            [
                (trainer.current_epoch + 1) % (self._log_every_n_epochs or 1) != 0,
                (trainer.global_step + 1) % (self._log_every_n_steps or 1) != 0,
            ]
        )
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Segmentation datasets related data loggers."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision
|
|
7
|
+
from lightning import pytorch as pl
|
|
8
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
9
|
+
from torch.nn import functional
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from eva.core.loggers import log
|
|
13
|
+
from eva.core.models.modules.typings import INPUT_TENSOR_BATCH
|
|
14
|
+
from eva.core.utils import to_cpu
|
|
15
|
+
from eva.vision.callbacks.loggers.batch import base
|
|
16
|
+
from eva.vision.utils import colormap, convert
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SemanticSegmentationLogger(base.BatchLogger):
    """Logs an image grid of inputs, targets and predictions for a segmentation batch.

    The step `outputs` must be a dict carrying a `predictions` tensor whose
    class dimension is dim 1 (presumably per-class logits — confirm with the
    producing module); it is arg-maxed into a predicted class mask.
    """

    def __init__(
        self,
        max_samples: int = 10,
        number_of_images_per_subgrid_row: int = 1,
        log_images: bool = True,
        mean: Tuple[float, ...] = (0.0, 0.0, 0.0),
        std: Tuple[float, ...] = (1.0, 1.0, 1.0),
        log_every_n_epochs: int | None = None,
        log_every_n_steps: int | None = None,
    ) -> None:
        """Initializes the segmentation batch logger.

        Args:
            max_samples: Upper bound on the number of samples taken from the
                batch and shown in the grid.
            number_of_images_per_subgrid_row: How many images each row of a
                sub-grid (images, targets, predictions, ...) contains.
            log_images: If `True`, the de-normalized input images and the
                mask overlays are added to the grid as well.
            mean: Per-channel mean used to de-normalize the input images.
            std: Per-channel std used to de-normalize the input images.
            log_every_n_epochs: Epoch-wise logging frequency.
            log_every_n_steps: Step-wise logging frequency.
        """
        super().__init__(
            log_every_n_epochs=log_every_n_epochs,
            log_every_n_steps=log_every_n_steps,
        )

        self._max_samples = max_samples
        self._number_of_images_per_subgrid_row = number_of_images_per_subgrid_row
        self._log_images = log_images
        self._mean = mean
        self._std = std

    @override
    def _log_batch(
        self,
        trainer: pl.Trainer,
        outputs: STEP_OUTPUT,
        batch: INPUT_TENSOR_BATCH,
        tag: str,
    ) -> None:
        if isinstance(outputs, dict):
            logits = outputs.get("predictions")
        else:
            logits = None
        if logits is None:
            raise ValueError("Key `predictions` is missing from the output.")

        sampled = _subsample_tensors(
            tensors_stack=[batch[0], batch[1], logits],
            max_samples=self._max_samples,
        )
        images, masks, logits = to_cpu(sampled)
        predicted_masks = torch.argmax(logits, dim=1)

        groups = [
            [_draw_semantic_mask(mask) for mask in masks],
            [_draw_semantic_mask(mask) for mask in predicted_masks],
        ]

        if self._log_images:
            denormed = [self._format_image(image) for image in images]
            groups = [
                denormed,
                [
                    _overlay_mask(img, mask)
                    for img, mask in zip(denormed, masks, strict=False)
                ],
                [
                    _overlay_mask(img, mask)
                    for img, mask in zip(denormed, predicted_masks, strict=False)
                ],
                *groups,
            ]

        grid = _make_grid_from_image_groups(groups, self._number_of_images_per_subgrid_row)
        log.log_image(
            trainer.loggers,
            image=grid,
            tag=tag,
            step=trainer.global_step,
        )

    def _format_image(self, image: torch.Tensor) -> torch.Tensor:
        """De-normalizes and de-scales an image tensor for display.

        Delegates to `convert.descale_and_denorm_image` with the configured
        mean/std (presumably yielding a (0, 255) uint8 tensor — see that helper).
        """
        return convert.descale_and_denorm_image(image, mean=self._mean, std=self._std)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _subsample_tensors(
|
|
106
|
+
tensors_stack: List[torch.Tensor],
|
|
107
|
+
max_samples: int,
|
|
108
|
+
) -> List[torch.Tensor]:
|
|
109
|
+
"""Sub-samples tensors from a list of tensors in-place.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
tensors_stack: A list of tensors.
|
|
113
|
+
max_samples: The maximum number of images
|
|
114
|
+
displayed in the grid.
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
A sub-sample of the input tensors stack.
|
|
118
|
+
"""
|
|
119
|
+
for i, tensor in enumerate(tensors_stack):
|
|
120
|
+
tensors_stack[i] = tensor[:max_samples]
|
|
121
|
+
return tensors_stack
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
|
|
125
|
+
"""Draws a semantic mask to an image RGB tensor.
|
|
126
|
+
|
|
127
|
+
The input semantic mask is a (H x W) shaped tensor with
|
|
128
|
+
integer values which represent the pixel class id.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
tensor: An image tensor of range [0., 1.].
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
The image as a tensor of range [0., 255.].
|
|
135
|
+
"""
|
|
136
|
+
tensor = torch.squeeze(tensor)
|
|
137
|
+
height, width = tensor.shape[-2], tensor.shape[-1]
|
|
138
|
+
red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
|
|
139
|
+
for class_id, color in colormap.COLORMAP.items():
|
|
140
|
+
indices = tensor == class_id
|
|
141
|
+
red[indices], green[indices], blue[indices] = color
|
|
142
|
+
return torch.stack([red, green, blue])
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Overlays a segmentation mask onto an image.

    Args:
        image: A 3D tensor of shape (C, H, W) representing the image.
        mask: A tensor of shape (H, W) representing the segmentation mask,
            optionally with leading singleton dimensions (e.g. (1, H, W)),
            which are dropped. Each pixel holds an integer class label.

    Returns:
        A tensor of the same shape as the input image (C, H, W) with the
        segmentation mask overlaid on top. Class id 0 is not drawn
        (presumably the background); every other class present in the mask
        is blended in with alpha 0.65 using the colors from the predefined
        colormap.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    # `one_hot` requires an int64 tensor and the permute below requires 2D.
    class_ids = mask.reshape(height, width).to(dtype=torch.long)
    binary_masks = functional.one_hot(class_ids).permute(2, 0, 1).to(dtype=torch.bool)
    return torchvision.utils.draw_segmentation_masks(
        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _make_grid_from_image_groups(
    image_groups: List[List[torch.Tensor]],
    number_of_images_per_subgrid_row: int = 2,
) -> torch.Tensor:
    """Combines several groups of images into one image grid.

    Each group (e.g. input images, targets, predictions) becomes its own
    sub-grid. The sub-grids are then placed next to each other, with
    `nrow` set to the number of groups.

    Args:
        image_groups: A list of lists of image tensors of shape (C x H x W),
            all of the same size.
        number_of_images_per_subgrid_row: How many images each row of a
            sub-grid contains.

    Returns:
        The combined image grid as a `torch.Tensor`.
    """
    columns = len(image_groups)
    subgrids = [
        torchvision.utils.make_grid(group, nrow=number_of_images_per_subgrid_row)
        for group in image_groups
    ]
    return torchvision.utils.make_grid(subgrids, nrow=columns)
|
|
@@ -1,15 +1,42 @@
|
|
|
1
1
|
"""Vision Datasets API."""
|
|
2
2
|
|
|
3
|
-
from eva.vision.data.datasets.classification import
|
|
4
|
-
|
|
3
|
+
from eva.vision.data.datasets.classification import (
|
|
4
|
+
BACH,
|
|
5
|
+
CRC,
|
|
6
|
+
MHIST,
|
|
7
|
+
PANDA,
|
|
8
|
+
Camelyon16,
|
|
9
|
+
PatchCamelyon,
|
|
10
|
+
WsiClassificationDataset,
|
|
11
|
+
)
|
|
12
|
+
from eva.vision.data.datasets.segmentation import (
|
|
13
|
+
BCSS,
|
|
14
|
+
CoNSeP,
|
|
15
|
+
EmbeddingsSegmentationDataset,
|
|
16
|
+
ImageSegmentation,
|
|
17
|
+
LiTS,
|
|
18
|
+
MoNuSAC,
|
|
19
|
+
TotalSegmentator2D,
|
|
20
|
+
)
|
|
5
21
|
from eva.vision.data.datasets.vision import VisionDataset
|
|
22
|
+
from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset
|
|
6
23
|
|
|
7
24
|
__all__ = [
|
|
8
25
|
"BACH",
|
|
26
|
+
"BCSS",
|
|
9
27
|
"CRC",
|
|
10
28
|
"MHIST",
|
|
11
|
-
"
|
|
29
|
+
"PANDA",
|
|
30
|
+
"Camelyon16",
|
|
12
31
|
"PatchCamelyon",
|
|
32
|
+
"WsiClassificationDataset",
|
|
33
|
+
"CoNSeP",
|
|
34
|
+
"EmbeddingsSegmentationDataset",
|
|
35
|
+
"ImageSegmentation",
|
|
36
|
+
"LiTS",
|
|
37
|
+
"MoNuSAC",
|
|
13
38
|
"TotalSegmentator2D",
|
|
14
39
|
"VisionDataset",
|
|
40
|
+
"MultiWsiDataset",
|
|
41
|
+
"WsiDataset",
|
|
15
42
|
]
|
|
@@ -13,7 +13,7 @@ _SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and
|
|
|
13
13
|
def check_dataset_integrity(
|
|
14
14
|
dataset: vision.VisionDataset,
|
|
15
15
|
*,
|
|
16
|
-
length: int,
|
|
16
|
+
length: int | None,
|
|
17
17
|
n_classes: int,
|
|
18
18
|
first_and_last_labels: Tuple[str, str],
|
|
19
19
|
) -> None:
|
|
@@ -23,7 +23,7 @@ def check_dataset_integrity(
|
|
|
23
23
|
ValueError: If the input dataset's values do not
|
|
24
24
|
match the expected ones.
|
|
25
25
|
"""
|
|
26
|
-
if len(dataset) != length:
|
|
26
|
+
if length and len(dataset) != length:
|
|
27
27
|
raise ValueError(
|
|
28
28
|
f"Dataset's '{dataset.__class__.__qualname__}' length "
|
|
29
29
|
f"({len(dataset)}) does not match the expected one ({length}). "
|
|
@@ -57,3 +57,16 @@ def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
|
|
|
57
57
|
if download_available:
|
|
58
58
|
error_message += " You can set `download=True` to download the dataset automatically."
|
|
59
59
|
raise FileNotFoundError(error_message)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def check_number_of_files(file_paths: List[str], expected_length: int, split: str | None) -> None:
|
|
63
|
+
"""Verifies the number of files in the dataset.
|
|
64
|
+
|
|
65
|
+
Raise:
|
|
66
|
+
ValueError: If the number of files in the dataset does not match the expected one.
|
|
67
|
+
"""
|
|
68
|
+
if len(file_paths) != expected_length:
|
|
69
|
+
raise ValueError(
|
|
70
|
+
f"Expected {expected_length} files, for split '{split}' found {len(file_paths)}. "
|
|
71
|
+
f"{_SUFFIX_ERROR_MESSAGE}"
|
|
72
|
+
)
|
|
@@ -1,8 +1,19 @@
|
|
|
1
1
|
"""Image classification datasets API."""
|
|
2
2
|
|
|
3
3
|
from eva.vision.data.datasets.classification.bach import BACH
|
|
4
|
+
from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
|
|
4
5
|
from eva.vision.data.datasets.classification.crc import CRC
|
|
5
6
|
from eva.vision.data.datasets.classification.mhist import MHIST
|
|
7
|
+
from eva.vision.data.datasets.classification.panda import PANDA
|
|
6
8
|
from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
|
|
9
|
+
from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
|
|
7
10
|
|
|
8
|
-
__all__ = [
|
|
11
|
+
__all__ = [
|
|
12
|
+
"BACH",
|
|
13
|
+
"CRC",
|
|
14
|
+
"MHIST",
|
|
15
|
+
"PatchCamelyon",
|
|
16
|
+
"WsiClassificationDataset",
|
|
17
|
+
"PANDA",
|
|
18
|
+
"Camelyon16",
|
|
19
|
+
]
|
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
import os
|
|
4
4
|
from typing import Callable, Dict, List, Literal, Tuple
|
|
5
5
|
|
|
6
|
-
import
|
|
6
|
+
import torch
|
|
7
|
+
from torchvision import tv_tensors
|
|
7
8
|
from torchvision.datasets import folder, utils
|
|
8
9
|
from typing_extensions import override
|
|
9
10
|
|
|
@@ -52,8 +53,7 @@ class BACH(base.ImageClassification):
|
|
|
52
53
|
root: str,
|
|
53
54
|
split: Literal["train", "val"] | None = None,
|
|
54
55
|
download: bool = False,
|
|
55
|
-
|
|
56
|
-
target_transforms: Callable | None = None,
|
|
56
|
+
transforms: Callable | None = None,
|
|
57
57
|
) -> None:
|
|
58
58
|
"""Initialize the dataset.
|
|
59
59
|
|
|
@@ -68,15 +68,10 @@ class BACH(base.ImageClassification):
|
|
|
68
68
|
Note that the download will be executed only by additionally
|
|
69
69
|
calling the :meth:`prepare_data` method and if the data does
|
|
70
70
|
not yet exist on disk.
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
target_transforms: A function/transform that takes in the target
|
|
74
|
-
and transforms it.
|
|
71
|
+
transforms: A function/transform which returns a transformed
|
|
72
|
+
version of the raw data samples.
|
|
75
73
|
"""
|
|
76
|
-
super().__init__(
|
|
77
|
-
image_transforms=image_transforms,
|
|
78
|
-
target_transforms=target_transforms,
|
|
79
|
-
)
|
|
74
|
+
super().__init__(transforms=transforms)
|
|
80
75
|
|
|
81
76
|
self._root = root
|
|
82
77
|
self._split = split
|
|
@@ -130,14 +125,14 @@ class BACH(base.ImageClassification):
|
|
|
130
125
|
)
|
|
131
126
|
|
|
132
127
|
@override
|
|
133
|
-
def load_image(self, index: int) ->
|
|
128
|
+
def load_image(self, index: int) -> tv_tensors.Image:
|
|
134
129
|
image_path, _ = self._samples[self._indices[index]]
|
|
135
|
-
return io.
|
|
130
|
+
return io.read_image_as_tensor(image_path)
|
|
136
131
|
|
|
137
132
|
@override
|
|
138
|
-
def load_target(self, index: int) ->
|
|
133
|
+
def load_target(self, index: int) -> torch.Tensor:
|
|
139
134
|
_, target = self._samples[self._indices[index]]
|
|
140
|
-
return
|
|
135
|
+
return torch.tensor(target, dtype=torch.long)
|
|
141
136
|
|
|
142
137
|
@override
|
|
143
138
|
def __len__(self) -> int:
|
|
@@ -3,32 +3,29 @@
|
|
|
3
3
|
import abc
|
|
4
4
|
from typing import Any, Callable, Dict, List, Tuple
|
|
5
5
|
|
|
6
|
-
import
|
|
6
|
+
import torch
|
|
7
|
+
from torchvision import tv_tensors
|
|
7
8
|
from typing_extensions import override
|
|
8
9
|
|
|
9
10
|
from eva.vision.data.datasets import vision
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
class ImageClassification(vision.VisionDataset[Tuple[
|
|
13
|
+
class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
|
|
13
14
|
"""Image classification abstract dataset."""
|
|
14
15
|
|
|
15
16
|
def __init__(
|
|
16
17
|
self,
|
|
17
|
-
|
|
18
|
-
target_transforms: Callable | None = None,
|
|
18
|
+
transforms: Callable | None = None,
|
|
19
19
|
) -> None:
|
|
20
20
|
"""Initializes the image classification dataset.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
target_transforms: A function/transform that takes in the target
|
|
26
|
-
and transforms it.
|
|
23
|
+
transforms: A function/transform which returns a transformed
|
|
24
|
+
version of the raw data samples.
|
|
27
25
|
"""
|
|
28
26
|
super().__init__()
|
|
29
27
|
|
|
30
|
-
self.
|
|
31
|
-
self._target_transforms = target_transforms
|
|
28
|
+
self._transforms = transforms
|
|
32
29
|
|
|
33
30
|
@property
|
|
34
31
|
def classes(self) -> List[str] | None:
|
|
@@ -38,19 +35,18 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
|
|
|
38
35
|
def class_to_idx(self) -> Dict[str, int] | None:
|
|
39
36
|
"""Returns a mapping of the class name to its target index."""
|
|
40
37
|
|
|
41
|
-
def load_metadata(self, index: int
|
|
38
|
+
def load_metadata(self, index: int) -> Dict[str, Any] | None:
|
|
42
39
|
"""Returns the dataset metadata.
|
|
43
40
|
|
|
44
41
|
Args:
|
|
45
42
|
index: The index of the data sample to return the metadata of.
|
|
46
|
-
If `None`, it will return the metadata of the current dataset.
|
|
47
43
|
|
|
48
44
|
Returns:
|
|
49
45
|
The sample metadata.
|
|
50
46
|
"""
|
|
51
47
|
|
|
52
48
|
@abc.abstractmethod
|
|
53
|
-
def load_image(self, index: int) ->
|
|
49
|
+
def load_image(self, index: int) -> tv_tensors.Image:
|
|
54
50
|
"""Returns the `index`'th image sample.
|
|
55
51
|
|
|
56
52
|
Args:
|
|
@@ -61,7 +57,7 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
|
|
|
61
57
|
"""
|
|
62
58
|
|
|
63
59
|
@abc.abstractmethod
|
|
64
|
-
def load_target(self, index: int) ->
|
|
60
|
+
def load_target(self, index: int) -> torch.Tensor:
|
|
65
61
|
"""Returns the `index`'th target sample.
|
|
66
62
|
|
|
67
63
|
Args:
|
|
@@ -77,14 +73,15 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
|
|
|
77
73
|
raise NotImplementedError
|
|
78
74
|
|
|
79
75
|
@override
|
|
80
|
-
def __getitem__(self, index: int) -> Tuple[
|
|
76
|
+
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
|
|
81
77
|
image = self.load_image(index)
|
|
82
78
|
target = self.load_target(index)
|
|
83
|
-
|
|
79
|
+
image, target = self._apply_transforms(image, target)
|
|
80
|
+
return image, target, self.load_metadata(index) or {}
|
|
84
81
|
|
|
85
82
|
def _apply_transforms(
|
|
86
|
-
self, image:
|
|
87
|
-
) -> Tuple[
|
|
83
|
+
self, image: tv_tensors.Image, target: torch.Tensor
|
|
84
|
+
) -> Tuple[tv_tensors.Image, torch.Tensor]:
|
|
88
85
|
"""Applies the transforms to the provided data and returns them.
|
|
89
86
|
|
|
90
87
|
Args:
|
|
@@ -94,10 +91,6 @@ class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], a
|
|
|
94
91
|
Returns:
|
|
95
92
|
A tuple with the image and the target transformed.
|
|
96
93
|
"""
|
|
97
|
-
if self.
|
|
98
|
-
image = self.
|
|
99
|
-
|
|
100
|
-
if self._target_transforms is not None:
|
|
101
|
-
target = self._target_transforms(target)
|
|
102
|
-
|
|
94
|
+
if self._transforms is not None:
|
|
95
|
+
image, target = self._transforms(image, target)
|
|
103
96
|
return image, target
|