kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,6 +4,7 @@ from typing import Any, Mapping
|
|
|
4
4
|
|
|
5
5
|
import lightning.pytorch as pl
|
|
6
6
|
import torch
|
|
7
|
+
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
|
|
7
8
|
from lightning.pytorch.utilities import memory
|
|
8
9
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
9
10
|
from typing_extensions import override
|
|
@@ -46,6 +47,21 @@ class ModelModule(pl.LightningModule):
|
|
|
46
47
|
"""The default post-processes."""
|
|
47
48
|
return batch_postprocess.BatchPostProcess()
|
|
48
49
|
|
|
50
|
+
@property
def metrics_device(self) -> torch.device:
    """Selects the device on which the metrics are calculated.

    With a single-device strategy the metrics are kept on the CPU, as it
    is much faster. With any other strategy they stay on the device of
    the module, as the DDP strategy requires the metrics to be allocated
    to the module's GPU.

    Returns:
        The CPU device for single-device strategies, otherwise the device
        of the module.
    """
    if isinstance(self.trainer.strategy, SingleDeviceStrategy):
        return torch.device("cpu")
    return self.device
|
|
60
|
+
|
|
61
|
+
@override
def on_fit_start(self) -> None:
    """Moves the metrics to the metrics device when fitting starts."""
    target_device = self.metrics_device
    self.metrics.to(device=target_device)
|
|
64
|
+
|
|
49
65
|
@override
|
|
50
66
|
def on_train_batch_end(
|
|
51
67
|
self,
|
|
@@ -59,6 +75,10 @@ class ModelModule(pl.LightningModule):
|
|
|
59
75
|
batch_outputs=outputs,
|
|
60
76
|
)
|
|
61
77
|
|
|
78
|
+
@override
def on_validation_start(self) -> None:
    """Moves the metrics to the metrics device when validation starts."""
    target_device = self.metrics_device
    self.metrics.to(device=target_device)
|
|
81
|
+
|
|
62
82
|
@override
|
|
63
83
|
def on_validation_batch_end(
|
|
64
84
|
self,
|
|
@@ -78,6 +98,10 @@ class ModelModule(pl.LightningModule):
|
|
|
78
98
|
def on_validation_epoch_end(self) -> None:
|
|
79
99
|
self._compute_and_log_metrics(self.metrics.validation_metrics)
|
|
80
100
|
|
|
101
|
+
@override
def on_test_start(self) -> None:
    """Moves the metrics to the metrics device when testing starts."""
    target_device = self.metrics_device
    self.metrics.to(device=target_device)
|
|
104
|
+
|
|
81
105
|
@override
|
|
82
106
|
def on_test_batch_end(
|
|
83
107
|
self,
|
|
@@ -110,7 +134,7 @@ class ModelModule(pl.LightningModule):
|
|
|
110
134
|
The updated outputs.
|
|
111
135
|
"""
|
|
112
136
|
self._postprocess(outputs)
|
|
113
|
-
return memory.recursive_detach(outputs, to_cpu=self.
|
|
137
|
+
return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
|
|
114
138
|
|
|
115
139
|
def _forward_and_log_metrics(
|
|
116
140
|
self,
|
|
@@ -16,7 +16,20 @@ class INPUT_BATCH(NamedTuple):
|
|
|
16
16
|
data: torch.Tensor
|
|
17
17
|
"""The data batch."""
|
|
18
18
|
|
|
19
|
-
targets: torch.Tensor |
|
|
19
|
+
targets: torch.Tensor | None = None
|
|
20
|
+
"""The target batch."""
|
|
21
|
+
|
|
22
|
+
metadata: Dict[str, Any] | None = None
|
|
23
|
+
"""The associated metadata."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class INPUT_TENSOR_BATCH(NamedTuple):
|
|
27
|
+
"""Tensor based input batch data scheme."""
|
|
28
|
+
|
|
29
|
+
data: torch.Tensor
|
|
30
|
+
"""The data batch."""
|
|
31
|
+
|
|
32
|
+
targets: torch.Tensor
|
|
20
33
|
"""The target batch."""
|
|
21
34
|
|
|
22
35
|
metadata: Dict[str, Any] | None = None
|
|
@@ -2,9 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
import functools
|
|
5
|
-
from typing import Callable, List
|
|
5
|
+
from typing import Any, Callable, Dict, List
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
+
from jsonargparse import _util
|
|
8
9
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
9
10
|
|
|
10
11
|
Transform = Callable[[torch.Tensor], torch.Tensor]
|
|
@@ -15,10 +16,10 @@ Transform = Callable[[torch.Tensor], torch.Tensor]
|
|
|
15
16
|
class BatchPostProcess:
|
|
16
17
|
"""Batch post-processes transform schema."""
|
|
17
18
|
|
|
18
|
-
targets_transforms: List[Transform] | None = None
|
|
19
|
+
targets_transforms: List[Transform | Dict[str, Any]] | None = None
|
|
19
20
|
"""Holds the common train and evaluation metrics."""
|
|
20
21
|
|
|
21
|
-
predictions_transforms: List[Transform] | None = None
|
|
22
|
+
predictions_transforms: List[Transform | Dict[str, Any]] | None = None
|
|
22
23
|
"""Holds the common train and evaluation metrics."""
|
|
23
24
|
|
|
24
25
|
def __call__(self, outputs: STEP_OUTPUT) -> None:
|
|
@@ -35,12 +36,13 @@ class BatchPostProcess:
|
|
|
35
36
|
|
|
36
37
|
if "targets" in outputs and self.targets_transforms is not None:
|
|
37
38
|
outputs["targets"] = _apply_transforms(
|
|
38
|
-
outputs["targets"], transforms=self.targets_transforms
|
|
39
|
+
outputs["targets"], transforms=_parse_callable_inputs(self.targets_transforms)
|
|
39
40
|
)
|
|
40
41
|
|
|
41
42
|
if "predictions" in outputs and self.predictions_transforms is not None:
|
|
42
43
|
outputs["predictions"] = _apply_transforms(
|
|
43
|
-
outputs["predictions"],
|
|
44
|
+
outputs["predictions"],
|
|
45
|
+
transforms=_parse_callable_inputs(self.predictions_transforms),
|
|
44
46
|
)
|
|
45
47
|
|
|
46
48
|
|
|
@@ -55,3 +57,33 @@ def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torc
|
|
|
55
57
|
The processed tensor.
|
|
56
58
|
"""
|
|
57
59
|
return functools.reduce(lambda tensor, transform: transform(tensor), transforms, tensor)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _parse_callable_inputs(inputs: List[Callable | Dict[str, Any]]) -> List[Callable]:
|
|
63
|
+
"""Parses the inputs which where passed as dictionary to callable objects."""
|
|
64
|
+
parsed = []
|
|
65
|
+
for item in inputs:
|
|
66
|
+
if isinstance(item, dict):
|
|
67
|
+
item = _parse_dict(item)
|
|
68
|
+
parsed.append(item)
|
|
69
|
+
return parsed
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _parse_dict(item: Dict[str, Any]) -> Callable:
    """Parses the input dictionary to a partial callable object.

    Args:
        item: The transform specification. It must hold the import path of
            the target (handed to `_util.import_object`) under `class_path`
            and may hold the keyword arguments to bind under `init_args`.

    Returns:
        A `functools.partial` of the imported object with the `init_args`
        bound as keyword arguments.

    Raises:
        ValueError: If the dictionary does not have the expected structure.
    """
    if not _is_valid_dict(item):
        raise ValueError(
            "Transform dictionary format is not valid. "
            "It must contain a key 'class_path' and optionally 'init_args' for "
            "the function and additional call arguments."
        )

    # `init_args` may be present but set to `None` (e.g. an empty entry in a
    # configuration file); unpacking `None` would raise a `TypeError`.
    init_args = item.get("init_args") or {}
    return functools.partial(_util.import_object(item["class_path"]), **init_args)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _is_valid_dict(item: Dict[str, Any], /) -> bool:
    """Checks if the input has the valid structure.

    Args:
        item: The dictionary to validate.

    Returns:
        `True` if the key `class_path` is present and there are no keys
        other than `class_path` and `init_args`, `False` otherwise.
    """
    # Key views support subset comparison directly; no set copy is needed.
    return "class_path" in item and item.keys() <= {"class_path", "init_args"}
|
eva/core/models/networks/mlp.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Multi-layer Perceptron (MLP) implemented in PyTorch."""
|
|
2
2
|
|
|
3
|
-
from typing import Type
|
|
3
|
+
from typing import Tuple, Type
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
import torch.nn as nn
|
|
@@ -13,7 +13,7 @@ class MLP(nn.Module):
|
|
|
13
13
|
self,
|
|
14
14
|
input_size: int,
|
|
15
15
|
output_size: int,
|
|
16
|
-
hidden_layer_sizes:
|
|
16
|
+
hidden_layer_sizes: Tuple[int, ...] | None = None,
|
|
17
17
|
hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
|
|
18
18
|
output_activation_fn: Type[torch.nn.Module] | None = None,
|
|
19
19
|
dropout: float = 0.0,
|
|
@@ -7,6 +7,14 @@ from transformers import modeling_outputs
|
|
|
7
7
|
class ExtractCLSFeatures:
|
|
8
8
|
"""Extracts the CLS token from a ViT model output."""
|
|
9
9
|
|
|
10
|
+
def __init__(self, cls_index: int = 0) -> None:
|
|
11
|
+
"""Initializes the transformation.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
cls_index: The index of the CLS token in the output tensor.
|
|
15
|
+
"""
|
|
16
|
+
self._cls_index = cls_index
|
|
17
|
+
|
|
10
18
|
def __call__(
|
|
11
19
|
self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
|
|
12
20
|
) -> torch.Tensor:
|
|
@@ -16,9 +24,9 @@ class ExtractCLSFeatures:
|
|
|
16
24
|
tensor: The tensor representing the model output.
|
|
17
25
|
"""
|
|
18
26
|
if isinstance(tensor, torch.Tensor):
|
|
19
|
-
transformed_tensor = tensor[:,
|
|
27
|
+
transformed_tensor = tensor[:, self._cls_index, :]
|
|
20
28
|
elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
|
|
21
|
-
transformed_tensor = tensor.last_hidden_state[:,
|
|
29
|
+
transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :]
|
|
22
30
|
else:
|
|
23
31
|
raise ValueError(f"Unsupported type {type(tensor)}")
|
|
24
32
|
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Transforms for extracting the patch features from a model output."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from transformers import modeling_outputs
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ExtractPatchFeatures:
    """Extracts the patch features from a ViT model output."""

    def __init__(self, ignore_remaining_dims: bool = False) -> None:
        """Initializes the transformation.

        Args:
            ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
                of the patch grid if it is not a square number.
        """
        self._ignore_remaining_dims = ignore_remaining_dims

    def __call__(
        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
    ) -> List[torch.Tensor]:
        """Call method for the transformation.

        Args:
            tensor: The raw embeddings of the model, either as a plain tensor
                or as a model output whose `last_hidden_state` holds them.

        Returns:
            A single-element list holding a tensor of shape
            (batch_size, hidden_size, n_patches_height, n_patches_width).

        Raises:
            ValueError: If the input type is not supported, or if the number
                of patch tokens is not a square number and
                `ignore_remaining_dims` is `False`.
        """
        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
            embeddings = tensor.last_hidden_state
        elif isinstance(tensor, torch.Tensor):
            # The signature advertises plain tensors, so handle them the same way.
            embeddings = tensor
        else:
            raise ValueError(f"Unsupported type {type(tensor)}")

        # Skip the first token and move the hidden dimension in front of the
        # token dimension, so the tokens can be unfolded into a 2D grid.
        features = embeddings[:, 1:, :].permute(0, 2, 1)
        batch_size, hidden_size, patch_grid = features.shape
        height = width = math.isqrt(patch_grid)
        if height * width != patch_grid:
            if not self._ignore_remaining_dims:
                raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
            features = features[:, :, : height * width]

        patch_embeddings = features.view(batch_size, hidden_size, height, width)
        return [patch_embeddings]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Model Wrappers API."""
|
|
2
|
+
|
|
3
|
+
from eva.core.models.wrappers.base import BaseModel
|
|
4
|
+
from eva.core.models.wrappers.from_function import ModelFromFunction
|
|
5
|
+
from eva.core.models.wrappers.huggingface import HuggingFaceModel
|
|
6
|
+
from eva.core.models.wrappers.onnx import ONNXModel
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"BaseModel",
|
|
10
|
+
"ModelFromFunction",
|
|
11
|
+
"HuggingFaceModel",
|
|
12
|
+
"ONNXModel",
|
|
13
|
+
]
|
|
@@ -22,6 +22,8 @@ class BaseModel(nn.Module):
|
|
|
22
22
|
|
|
23
23
|
self._output_transforms = tensor_transforms
|
|
24
24
|
|
|
25
|
+
self._model: Callable[..., torch.Tensor] | nn.Module
|
|
26
|
+
|
|
25
27
|
@override
|
|
26
28
|
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
27
29
|
tensor = self.model_forward(tensor)
|
|
@@ -32,14 +34,13 @@ class BaseModel(nn.Module):
|
|
|
32
34
|
"""Loads the model."""
|
|
33
35
|
raise NotImplementedError
|
|
34
36
|
|
|
35
|
-
@abc.abstractmethod
|
|
36
37
|
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
    """Runs the forward pass of the wrapped model.

    Args:
        tensor: The input tensor to the model.

    Returns:
        The result of calling the wrapped model on the input tensor.
    """
    output = self._model(tensor)
    return output
|
|
43
44
|
|
|
44
45
|
def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
45
46
|
if self._output_transforms is not None:
|
|
@@ -3,12 +3,10 @@
|
|
|
3
3
|
from typing import Any, Callable, Dict
|
|
4
4
|
|
|
5
5
|
import jsonargparse
|
|
6
|
-
import torch
|
|
7
6
|
from torch import nn
|
|
8
7
|
from typing_extensions import override
|
|
9
8
|
|
|
10
|
-
from eva.core.models.
|
|
11
|
-
from eva.core.models.networks.wrappers import base
|
|
9
|
+
from eva.core.models.wrappers import _utils, base
|
|
12
10
|
|
|
13
11
|
|
|
14
12
|
class ModelFromFunction(base.BaseModel):
|
|
@@ -36,23 +34,18 @@ class ModelFromFunction(base.BaseModel):
|
|
|
36
34
|
tensor_transforms: The transforms to apply to the output tensor
|
|
37
35
|
produced by the model.
|
|
38
36
|
"""
|
|
39
|
-
super().__init__()
|
|
37
|
+
super().__init__(tensor_transforms=tensor_transforms)
|
|
40
38
|
|
|
41
39
|
self._path = path
|
|
42
40
|
self._arguments = arguments
|
|
43
41
|
self._checkpoint_path = checkpoint_path
|
|
44
|
-
self._tensor_transforms = tensor_transforms
|
|
45
42
|
|
|
46
|
-
self.
|
|
43
|
+
self.load_model()
|
|
47
44
|
|
|
48
45
|
@override
|
|
49
|
-
def load_model(self) ->
|
|
46
|
+
def load_model(self) -> None:
    """Builds the model from the function path and stores it.

    The weights are loaded from the checkpoint path when one is set.
    """
    factory = jsonargparse.class_from_function(self._path, func_return=nn.Module)
    keyword_arguments = self._arguments or {}
    network = factory(**keyword_arguments)
    if self._checkpoint_path is not None:
        _utils.load_model_weights(network, self._checkpoint_path)
    self._model = network
|
|
@@ -1,18 +1,22 @@
|
|
|
1
1
|
"""Wrappers for HuggingFace `transformers` models."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Callable
|
|
3
|
+
from typing import Any, Callable, Dict
|
|
4
4
|
|
|
5
|
-
import torch
|
|
6
5
|
import transformers
|
|
7
6
|
from typing_extensions import override
|
|
8
7
|
|
|
9
|
-
from eva.core.models.
|
|
8
|
+
from eva.core.models.wrappers import base
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
class HuggingFaceModel(base.BaseModel):
|
|
13
12
|
"""Wrapper class for loading HuggingFace `transformers` models."""
|
|
14
13
|
|
|
15
|
-
def __init__(
|
|
14
|
+
def __init__(
|
|
15
|
+
self,
|
|
16
|
+
model_name_or_path: str,
|
|
17
|
+
tensor_transforms: Callable | None = None,
|
|
18
|
+
model_kwargs: Dict[str, Any] | None = None,
|
|
19
|
+
) -> None:
|
|
16
20
|
"""Initializes the model.
|
|
17
21
|
|
|
18
22
|
Args:
|
|
@@ -21,17 +25,17 @@ class HuggingFaceModel(base.BaseModel):
|
|
|
21
25
|
model hub.
|
|
22
26
|
tensor_transforms: The transforms to apply to the output tensor
|
|
23
27
|
produced by the model.
|
|
28
|
+
model_kwargs: The arguments used for instantiating the model.
|
|
24
29
|
"""
|
|
25
30
|
super().__init__(tensor_transforms=tensor_transforms)
|
|
26
31
|
|
|
27
32
|
self._model_name_or_path = model_name_or_path
|
|
28
|
-
self.
|
|
33
|
+
self._model_kwargs = model_kwargs or {}
|
|
29
34
|
|
|
30
|
-
|
|
31
|
-
def load_model(self) -> Any:
|
|
32
|
-
config = transformers.AutoConfig.from_pretrained(self._model_name_or_path)
|
|
33
|
-
return transformers.AutoModel.from_pretrained(self._model_name_or_path, config=config)
|
|
35
|
+
self.load_model()
|
|
34
36
|
|
|
35
37
|
@override
|
|
36
|
-
def
|
|
37
|
-
|
|
38
|
+
def load_model(self) -> None:
    """Loads the pretrained `transformers` model and stores it."""
    pretrained = transformers.AutoModel.from_pretrained(
        self._model_name_or_path,
        **self._model_kwargs,
    )
    self._model = pretrained
|
|
@@ -6,7 +6,7 @@ import onnxruntime as ort
|
|
|
6
6
|
import torch
|
|
7
7
|
from typing_extensions import override
|
|
8
8
|
|
|
9
|
-
from eva.core.models.
|
|
9
|
+
from eva.core.models.wrappers import base
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class ONNXModel(base.BaseModel):
|
|
@@ -29,19 +29,22 @@ class ONNXModel(base.BaseModel):
|
|
|
29
29
|
|
|
30
30
|
self._path = path
|
|
31
31
|
self._device = device
|
|
32
|
-
|
|
32
|
+
|
|
33
|
+
self.load_model()
|
|
33
34
|
|
|
34
35
|
@override
def load_model(self) -> None:
    """Creates the ONNX inference session and stores it as the model.

    The CUDA execution provider is used when the device is `cuda`,
    otherwise the CPU execution provider.

    Raises:
        ValueError: If the device is `cuda` but CUDA is not available.
    """
    if self._device == "cuda" and not torch.cuda.is_available():
        raise ValueError("Device is set to 'cuda', but CUDA is not available.")
    provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
    self._model = ort.InferenceSession(self._path, providers=[provider])  # type: ignore
|
|
40
41
|
|
|
41
42
|
@override
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
    """Runs the ONNX session on the input tensor.

    Args:
        tensor: The input tensor to the model.

    Returns:
        The first output of the session as a float tensor, placed on the
        device of the input tensor.

    Raises:
        ValueError: If the model is not an ONNX inference session.
    """
    # TODO: Use IO binding to avoid copying the tensor to CPU.
    # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
    session = self._model
    if not isinstance(session, ort.InferenceSession):
        raise ValueError("Model is not loaded.")
    input_name = session.get_inputs()[0].name
    feed = {input_name: tensor.detach().cpu().numpy()}
    first_output = session.run(None, feed)[0]
    result = torch.from_numpy(first_output).float()
    return result.to(tensor.device)
|
eva/core/trainers/_recorder.py
CHANGED
|
@@ -5,18 +5,41 @@ import json
|
|
|
5
5
|
import os
|
|
6
6
|
import statistics
|
|
7
7
|
import sys
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Dict, List, Mapping, TypedDict
|
|
9
9
|
|
|
10
10
|
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
|
|
11
11
|
from lightning_fabric.utilities import cloud_io
|
|
12
12
|
from loguru import logger
|
|
13
13
|
from omegaconf import OmegaConf
|
|
14
|
+
from rich import console as rich_console
|
|
15
|
+
from rich import table as rich_table
|
|
14
16
|
from toolz import dicttoolz
|
|
15
17
|
|
|
16
18
|
SESSION_METRICS = Mapping[str, List[float]]
|
|
17
19
|
"""Session metrics type-hint."""
|
|
18
20
|
|
|
19
21
|
|
|
22
|
+
class SESSION_STATISTICS(TypedDict):
    """Type-hint for aggregated metrics of multiple runs with mean & stdev."""

    # The mean of the metric values.
    mean: float
    # The standard deviation of the metric values.
    stdev: float
    # The individual metric values; presumably one per run -- confirm.
    values: List[float]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class STAGE_RESULTS(TypedDict):
    """Type-hint for metrics statistics for val & test stages."""

    # The validation statistics: one metric-name to statistics mapping per entry.
    val: List[Dict[str, SESSION_STATISTICS]]
    # The test statistics, in the same layout as `val`.
    test: List[Dict[str, SESSION_STATISTICS]]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RESULTS_DICT(TypedDict):
    """Type-hint for the final results dictionary."""

    # The statistics of the val & test stages.
    metrics: STAGE_RESULTS
|
|
41
|
+
|
|
42
|
+
|
|
20
43
|
class SessionRecorder:
|
|
21
44
|
"""Multi-run (session) summary logger."""
|
|
22
45
|
|
|
@@ -25,6 +48,7 @@ class SessionRecorder:
|
|
|
25
48
|
output_dir: str,
|
|
26
49
|
results_file: str = "results.json",
|
|
27
50
|
config_file: str = "config.yaml",
|
|
51
|
+
verbose: bool = True,
|
|
28
52
|
) -> None:
|
|
29
53
|
"""Initializes the recorder.
|
|
30
54
|
|
|
@@ -32,10 +56,12 @@ class SessionRecorder:
|
|
|
32
56
|
output_dir: The destination folder to save the results.
|
|
33
57
|
results_file: The name of the results json file.
|
|
34
58
|
config_file: The name of the yaml configuration file.
|
|
59
|
+
verbose: Whether to print the session metrics.
|
|
35
60
|
"""
|
|
36
61
|
self._output_dir = output_dir
|
|
37
62
|
self._results_file = results_file
|
|
38
63
|
self._config_file = config_file
|
|
64
|
+
self._verbose = verbose
|
|
39
65
|
|
|
40
66
|
self._validation_metrics: List[SESSION_METRICS] = []
|
|
41
67
|
self._test_metrics: List[SESSION_METRICS] = []
|
|
@@ -67,13 +93,13 @@ class SessionRecorder:
|
|
|
67
93
|
self._update_validation_metrics(validation_scores)
|
|
68
94
|
self._update_test_metrics(test_scores)
|
|
69
95
|
|
|
70
|
-
def compute(self) ->
|
|
96
|
+
def compute(self) -> STAGE_RESULTS:
    """Aggregates the tracked metrics into the session statistics.

    Returns:
        The statistics of the validation metrics under `val` and those of
        the test metrics under `test`.
    """
    return {
        "val": [_calculate_statistics(metrics) for metrics in self._validation_metrics],
        "test": [_calculate_statistics(metrics) for metrics in self._test_metrics],
    }
|
|
75
101
|
|
|
76
|
-
def export(self) ->
|
|
102
|
+
def export(self) -> RESULTS_DICT:
    """Exports the computed session statistics under the `metrics` key."""
    return {"metrics": self.compute()}
|
|
@@ -83,6 +109,8 @@ class SessionRecorder:
|
|
|
83
109
|
results = self.export()
|
|
84
110
|
_save_json(results, self.filename)
|
|
85
111
|
self._save_config()
|
|
112
|
+
if self._verbose:
|
|
113
|
+
_print_results(results)
|
|
86
114
|
|
|
87
115
|
def reset(self) -> None:
|
|
88
116
|
"""Resets the state of the tracked metrics."""
|
|
@@ -125,10 +153,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
|
|
|
125
153
|
return [collections.defaultdict(list) for _ in range(n_datasets)]
|
|
126
154
|
|
|
127
155
|
|
|
128
|
-
def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str,
|
|
156
|
+
def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
|
|
129
157
|
"""Calculate the metric statistics of a dataset session run."""
|
|
130
158
|
|
|
131
|
-
def _calculate_metric_statistics(values: List[float]) ->
|
|
159
|
+
def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
|
|
132
160
|
"""Calculates and returns the metric statistics."""
|
|
133
161
|
mean = statistics.mean(values)
|
|
134
162
|
stdev = statistics.stdev(values) if len(values) > 1 else 0
|
|
@@ -137,7 +165,7 @@ def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float |
|
|
|
137
165
|
return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)
|
|
138
166
|
|
|
139
167
|
|
|
140
|
-
def _save_json(data:
|
|
168
|
+
def _save_json(data: RESULTS_DICT, save_as: str = "data.json"):
|
|
141
169
|
"""Saves data to a json file."""
|
|
142
170
|
if not save_as.endswith(".json"):
|
|
143
171
|
raise ValueError()
|
|
@@ -146,4 +174,38 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
|
|
|
146
174
|
fs = cloud_io.get_filesystem(output_dir, anon=False)
|
|
147
175
|
fs.makedirs(output_dir, exist_ok=True)
|
|
148
176
|
with fs.open(save_as, "w") as file:
|
|
149
|
-
json.dump(data, file, indent=
|
|
177
|
+
json.dump(data, file, indent=2, sort_keys=True)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _print_results(results: RESULTS_DICT) -> None:
    """Prints the results to the console.

    One table is printed per dataset of the `val` and `test` stages.
    Printing is best-effort: a failure is logged as an error, not raised.

    Args:
        results: The results dictionary, holding the per-stage lists of
            dataset metrics under the `metrics` key.
    """
    try:
        for stage in ["val", "test"]:
            for dataset_idx, dataset_metrics in enumerate(results["metrics"][stage]):
                _print_table(dataset_metrics, stage, dataset_idx)
    except Exception as e:
        # Deliberately broad: a printing failure must not fail the session.
        logger.error(f"Failed to print the results: {e}")
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
|
|
191
|
+
"""Prints the metrics of a single dataset as a table."""
|
|
192
|
+
metrics_table = rich_table.Table(
|
|
193
|
+
title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
|
|
194
|
+
)
|
|
195
|
+
metrics_table.add_column("Metric", style="cyan")
|
|
196
|
+
metrics_table.add_column("Mean", style="magenta")
|
|
197
|
+
metrics_table.add_column("Stdev", style="magenta")
|
|
198
|
+
metrics_table.add_column("All", style="magenta")
|
|
199
|
+
|
|
200
|
+
n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
|
|
201
|
+
for metric_name, metric_dict in metrics_dict.items():
|
|
202
|
+
row = [
|
|
203
|
+
metric_name,
|
|
204
|
+
f'{metric_dict["mean"]:.3f}',
|
|
205
|
+
f'{metric_dict["stdev"]:.3f}',
|
|
206
|
+
", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)),
|
|
207
|
+
]
|
|
208
|
+
metrics_table.add_row(*row)
|
|
209
|
+
|
|
210
|
+
console = rich_console.Console()
|
|
211
|
+
console.print(metrics_table)
|
eva/core/trainers/functional.py
CHANGED
|
@@ -16,6 +16,7 @@ def run_evaluation_session(
|
|
|
16
16
|
datamodule: datamodules.DataModule,
|
|
17
17
|
*,
|
|
18
18
|
n_runs: int = 1,
|
|
19
|
+
verbose: bool = True,
|
|
19
20
|
) -> None:
|
|
20
21
|
"""Runs a downstream evaluation session out-of-place.
|
|
21
22
|
|
|
@@ -29,11 +30,17 @@ def run_evaluation_session(
|
|
|
29
30
|
base_model: The base model module to use.
|
|
30
31
|
datamodule: The data module.
|
|
31
32
|
n_runs: The amount of runs (fit and evaluate) to perform.
|
|
33
|
+
verbose: Whether to verbose the session metrics instead of
|
|
34
|
+
these of each individual runs and vice-versa.
|
|
32
35
|
"""
|
|
33
|
-
recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
|
|
36
|
+
recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
|
|
34
37
|
for run_index in range(n_runs):
|
|
35
38
|
validation_scores, test_scores = run_evaluation(
|
|
36
|
-
base_trainer,
|
|
39
|
+
base_trainer,
|
|
40
|
+
base_model,
|
|
41
|
+
datamodule,
|
|
42
|
+
run_id=f"run_{run_index}",
|
|
43
|
+
verbose=not verbose,
|
|
37
44
|
)
|
|
38
45
|
recorder.update(validation_scores, test_scores)
|
|
39
46
|
recorder.save()
|
|
@@ -45,6 +52,7 @@ def run_evaluation(
|
|
|
45
52
|
datamodule: datamodules.DataModule,
|
|
46
53
|
*,
|
|
47
54
|
run_id: str | None = None,
|
|
55
|
+
verbose: bool = True,
|
|
48
56
|
) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
|
|
49
57
|
"""Fits and evaluates a model out-of-place.
|
|
50
58
|
|
|
@@ -54,19 +62,23 @@ def run_evaluation(
|
|
|
54
62
|
datamodule: The data module.
|
|
55
63
|
run_id: The run id to be appended to the output log directory.
|
|
56
64
|
If `None`, it will use the log directory of the trainer as is.
|
|
65
|
+
verbose: Whether to print the validation and test metrics
|
|
66
|
+
in the end of the training.
|
|
57
67
|
|
|
58
68
|
Returns:
|
|
59
69
|
A tuple of with the validation and the test metrics (if exists).
|
|
60
70
|
"""
|
|
61
71
|
trainer, model = _utils.clone(base_trainer, base_model)
|
|
72
|
+
model.configure_model()
|
|
62
73
|
trainer.setup_log_dirs(run_id or "")
|
|
63
|
-
return fit_and_validate(trainer, model, datamodule)
|
|
74
|
+
return fit_and_validate(trainer, model, datamodule, verbose=verbose)
|
|
64
75
|
|
|
65
76
|
|
|
66
77
|
def fit_and_validate(
|
|
67
78
|
trainer: eva_trainer.Trainer,
|
|
68
79
|
model: modules.ModelModule,
|
|
69
80
|
datamodule: datamodules.DataModule,
|
|
81
|
+
verbose: bool = True,
|
|
70
82
|
) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
|
|
71
83
|
"""Fits and evaluates a model in-place.
|
|
72
84
|
|
|
@@ -77,13 +89,19 @@ def fit_and_validate(
|
|
|
77
89
|
trainer: The trainer module to use and update in-place.
|
|
78
90
|
model: The model module to use and update in-place.
|
|
79
91
|
datamodule: The data module.
|
|
92
|
+
verbose: Whether to print the validation and test metrics
|
|
93
|
+
in the end of the training.
|
|
80
94
|
|
|
81
95
|
Returns:
|
|
82
96
|
A tuple of with the validation and the test metrics (if exists).
|
|
83
97
|
"""
|
|
84
98
|
trainer.fit(model, datamodule=datamodule)
|
|
85
|
-
validation_scores = trainer.validate(datamodule=datamodule)
|
|
86
|
-
test_scores =
|
|
99
|
+
validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose)
|
|
100
|
+
test_scores = (
|
|
101
|
+
None
|
|
102
|
+
if datamodule.datasets.test is None
|
|
103
|
+
else trainer.test(datamodule=datamodule, verbose=verbose)
|
|
104
|
+
)
|
|
87
105
|
return validation_scores, test_scores
|
|
88
106
|
|
|
89
107
|
|