kaiko_eva-0.0.0.dev6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva has been flagged as possibly problematic.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/core/data/datamodules/datamodule.py
@@ -0,0 +1,108 @@
+"""Core DataModule."""
+
+from typing import List
+
+import lightning.pytorch as pl
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from typing_extensions import override
+
+from eva.core.data import dataloaders as dataloaders_lib
+from eva.core.data import datasets as datasets_lib
+from eva.core.data.datamodules import call, schemas
+
+
+class DataModule(pl.LightningDataModule):
+    """DataModule encapsulates all the steps needed to process data.
+
+    It will initialize and create the mapping between dataloaders and
+    datasets. During the `prepare_data`, `setup` and `teardown`, the
+    datamodule will call the respective methods from all datasets,
+    given that they are defined.
+    """
+
+    def __init__(
+        self,
+        datasets: schemas.DatasetsSchema | None = None,
+        dataloaders: schemas.DataloadersSchema | None = None,
+    ) -> None:
+        """Initializes the datamodule.
+
+        Args:
+            datasets: The desired datasets.
+            dataloaders: The desired dataloaders.
+        """
+        super().__init__()
+
+        self.datasets = datasets or self.default_datasets
+        self.dataloaders = dataloaders or self.default_dataloaders
+
+    @property
+    def default_datasets(self) -> schemas.DatasetsSchema:
+        """Returns the default datasets."""
+        return schemas.DatasetsSchema()
+
+    @property
+    def default_dataloaders(self) -> schemas.DataloadersSchema:
+        """Returns the default dataloader schema."""
+        return schemas.DataloadersSchema()
+
+    @override
+    def prepare_data(self) -> None:
+        call.call_method_if_exists(self.datasets.tolist(), "prepare_data")
+
+    @override
+    def setup(self, stage: str) -> None:
+        call.call_method_if_exists(self.datasets.tolist(stage), "setup")
+
+    @override
+    def teardown(self, stage: str) -> None:
+        call.call_method_if_exists(self.datasets.tolist(stage), "teardown")
+
+    @override
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        if self.datasets.train is None:
+            raise ValueError(
+                "Train dataloader can not be initialized as `self.datasets.train` is `None`."
+            )
+        return self.dataloaders.train(self.datasets.train)
+
+    @override
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        if self.datasets.val is None:
+            raise ValueError(
+                "Validation dataloader can not be initialized as `self.datasets.val` is `None`."
+            )
+        return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val)
+
+    @override
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        if self.datasets.test is None:
+            raise ValueError(
+                "Test dataloader can not be initialized as `self.datasets.test` is `None`."
+            )
+        return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test)
+
+    @override
+    def predict_dataloader(self) -> EVAL_DATALOADERS:
+        if self.datasets.predict is None:
+            raise ValueError(
+                "Predict dataloader can not be initialized as `self.datasets.predict` is `None`."
+            )
+        return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict)
+
+    def _initialize_dataloaders(
+        self,
+        dataloader: dataloaders_lib.DataLoader,
+        datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset],
+    ) -> EVAL_DATALOADERS:
+        """Initializes dataloaders from a given set of datasets.
+
+        Args:
+            dataloader: The dataloader to apply to the provided datasets.
+            datasets: The desired dataset(s) to allocate dataloader(s).
+
+        Returns:
+            A list with the dataloaders of the provided dataset(s).
+        """
+        datasets = datasets if isinstance(datasets, list) else [datasets]
+        return list(map(dataloader, datasets))
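For orientation, a minimal usage sketch of the datamodule above; it is not part of the diff. `ToyDataset` is hypothetical, and the sketch assumes that `TorchDataset` aliases `torch.utils.data.Dataset`, that `call_method_if_exists` skips datasets lacking a hook, and that the default `DataloadersSchema` turns each dataset into a standard dataloader.

from torch.utils.data import Dataset

from eva.core.data.datamodules import schemas
from eva.core.data.datamodules.datamodule import DataModule


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset, for illustration only."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, index: int) -> int:
        return index


# Only `train` and `val` are provided; the `test`/`predict` slots stay `None`,
# so e.g. `test_dataloader()` would raise the ValueError shown above.
datamodule = DataModule(datasets=schemas.DatasetsSchema(train=ToyDataset(), val=ToyDataset()))
datamodule.setup(stage="fit")  # forwards `setup` to the datasets that define it
train_loader = datamodule.train_dataloader()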
eva/core/data/datamodules/schemas.py
@@ -0,0 +1,62 @@
+"""Argument schemas used in DataModule."""
+
+import dataclasses
+from typing import List
+
+from eva.core.data import dataloaders, datasets
+
+TRAIN_DATASET = datasets.TorchDataset | None
+"""Train dataset."""
+
+EVAL_DATASET = datasets.TorchDataset | List[datasets.TorchDataset] | None
+"""Evaluation dataset."""
+
+
+@dataclasses.dataclass(frozen=True)
+class DatasetsSchema:
+    """Datasets schema used in DataModule."""
+
+    train: TRAIN_DATASET = None
+    """Train dataset."""
+
+    val: EVAL_DATASET = None
+    """Validation dataset."""
+
+    test: EVAL_DATASET = None
+    """Test dataset."""
+
+    predict: EVAL_DATASET = None
+    """Predict dataset."""
+
+    def tolist(self, stage: str | None = None) -> List[EVAL_DATASET]:
+        """Returns the dataclass as a list and optionally filters it given the stage."""
+        match stage:
+            case "fit":
+                return [self.train, self.val]
+            case "validate":
+                return [self.val]
+            case "test":
+                return [self.test]
+            case "predict":
+                return [self.predict]
+            case None:
+                return [self.train, self.val, self.test, self.predict]
+            case _:
+                raise ValueError(f"Invalid stage `{stage}`.")
+
+
+@dataclasses.dataclass(frozen=True)
+class DataloadersSchema:
+    """Dataloaders schema used in DataModule."""
+
+    train: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+    """Train dataloader."""
+
+    val: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+    """Validation dataloader."""
+
+    test: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+    """Test dataloader."""
+
+    predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+    """Predict dataloader."""
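The stage names accepted by `tolist` mirror Lightning's trainer stages. A small sketch of the expected filtering (plain objects stand in for datasets here; dataclass field annotations are not enforced at runtime):

from eva.core.data.datamodules import schemas

train_ds, val_ds = object(), object()  # hypothetical stand-ins for real datasets
schema = schemas.DatasetsSchema(train=train_ds, val=val_ds)

assert schema.tolist("fit") == [train_ds, val_ds]
assert schema.tolist("validate") == [val_ds]
assert schema.tolist() == [train_ds, val_ds, None, None]  # no stage: all four slots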
eva/core/data/datasets/__init__.py
@@ -0,0 +1,7 @@
+"""Datasets API."""
+
+from eva.core.data.datasets.base import Dataset
+from eva.core.data.datasets.classification import EmbeddingsClassificationDataset
+from eva.core.data.datasets.dataset import TorchDataset
+
+__all__ = ["Dataset", "EmbeddingsClassificationDataset", "TorchDataset"]
eva/core/data/datasets/base.py
@@ -0,0 +1,53 @@
+"""Base dataset class."""
+
+from eva.core.data.datasets import dataset
+
+
+class Dataset(dataset.TorchDataset):
+    """Base dataset class."""
+
+    def prepare_data(self) -> None:
+        """Encapsulates all disk-related tasks.
+
+        This method is preferred for downloading and preparing the data, for
+        example generating manifest files. If implemented, it will be called via
+        :class:`eva.core.data.datamodules.DataModule`, which ensures that it is
+        called only within a single process, making it multiprocessing-safe.
+        """
+
+    def setup(self) -> None:
+        """Sets up the dataset.
+
+        This method is preferred for creating datasets or performing
+        train/val/test splits. If implemented, it will be called via
+        :class:`eva.core.data.datamodules.DataModule` at the beginning of fit
+        (train + validate), validate, test, or predict and it will be called
+        from every process (i.e. GPU) across all the nodes in DDP.
+        """
+        self.configure()
+        self.validate()
+
+    def configure(self):
+        """Configures the dataset.
+
+        This method is preferred for configuring the dataset; assigning values
+        to attributes, performing splits etc. It is called from
+        :meth:`setup`, before calling :meth:`validate`.
+        """
+
+    def validate(self):
+        """Validates the dataset.
+
+        This method aims to check the integrity of the dataset and verify
+        that it is configured properly. It is called from :meth:`setup`,
+        after calling :meth:`configure`.
+        """
+
+    def teardown(self) -> None:
+        """Cleans up the data artifacts.
+
+        Used to clean up when the run is finished. If implemented, it will
+        be called via :class:`eva.core.data.datamodules.DataModule` at the end
+        of fit (train + validate), validate, test, or predict and it will be
+        called from every process (i.e. GPU) across all the nodes in DDP.
+        """
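A sketch of the subclassing pattern the base class implies: `configure` assigns state, `validate` checks it, and `setup` drives both. The class and paths below are hypothetical.

from typing_extensions import override

from eva.core.data.datasets import base


class ManifestBackedDataset(base.Dataset):
    """Hypothetical dataset illustrating the configure/validate hooks."""

    def __init__(self, root: str) -> None:
        super().__init__()
        self._root = root
        self._samples: list = []

    @override
    def configure(self) -> None:
        # Assign attributes, read manifests, perform splits, etc.
        self._samples = [f"{self._root}/sample_{i}.pt" for i in range(10)]

    @override
    def validate(self) -> None:
        if not self._samples:
            raise RuntimeError("Dataset is not configured properly.")


dataset = ManifestBackedDataset(root="/data/embeddings")
dataset.setup()  # runs configure() first, then validate()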
eva/core/data/datasets/classification/embeddings.py
@@ -0,0 +1,154 @@
+"""Embeddings classification dataset."""
+
+import os
+from typing import Callable, Dict, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import base
+from eva.core.utils import io
+
+
+class EmbeddingsClassificationDataset(base.Dataset):
+    """Embeddings classification dataset."""
+
+    default_column_mapping: Dict[str, str] = {
+        "data": "embeddings",
+        "target": "target",
+        "split": "split",
+    }
+    """The default column mapping of the variables to the manifest columns."""
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: str | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of .pt files that contain
+        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The manifest file will be filtered
+                based on the value in its `split` column.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__()
+
+        self._root = root
+        self._manifest_file = manifest_file
+        self._split = split
+        self._column_mapping = self.default_column_mapping | column_mapping
+        self._embeddings_transforms = embeddings_transforms
+        self._target_transforms = target_transforms
+
+        self._data: pd.DataFrame
+
+    def filename(self, index: int) -> str:
+        """Returns the filename of the `index`'th data sample.
+
+        Note that this is the relative file path to the root.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            The filename of the `index`'th data sample.
+        """
+        return self._data.at[index, self._column_mapping["data"]]
+
+    @override
+    def setup(self):
+        self._data = self._load_manifest()
+
+    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
+        """Returns the `index`'th data sample.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            A data sample and its target.
+        """
+        embeddings = self._load_embeddings(index)
+        target = self._load_target(index)
+        return self._apply_transforms(embeddings, target)
+
+    def __len__(self) -> int:
+        """Returns the total length of the data."""
+        return len(self._data)
+
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Returns the `index`'th embedding sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample embedding as a tensor.
+        """
+        filename = self.filename(index)
+        embeddings_path = os.path.join(self._root, filename)
+        tensor = torch.load(embeddings_path, map_location="cpu")
+        return tensor.squeeze(0)
+
+    def _load_target(self, index: int) -> np.ndarray:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+        target = self._data.at[index, self._column_mapping["target"]]
+        return np.asarray(target, dtype=np.int64)
+
+    def _load_manifest(self) -> pd.DataFrame:
+        """Loads manifest file and filters the data based on the split column.
+
+        Returns:
+            The data as a pandas DataFrame.
+        """
+        manifest_path = os.path.join(self._root, self._manifest_file)
+        data = io.read_dataframe(manifest_path)
+        if self._split is not None:
+            filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
+            data = filtered_data.reset_index(drop=True)
+        return data
+
+    def _apply_transforms(
+        self, embeddings: torch.Tensor, target: np.ndarray
+    ) -> Tuple[torch.Tensor, np.ndarray]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            embeddings: The embeddings to be transformed.
+            target: The training target.
+
+        Returns:
+            A tuple with the embeddings and the target transformed.
+        """
+        if self._embeddings_transforms is not None:
+            embeddings = self._embeddings_transforms(embeddings)
+
+        if self._target_transforms is not None:
+            target = self._target_transforms(target)
+
+        return embeddings, target
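To make the expected inputs concrete, a hedged sketch of a matching manifest and the dataset calls. The file contents and paths are hypothetical, and the manifest is assumed to be in a format `io.read_dataframe` accepts (e.g. CSV):

# manifest.csv (hypothetical), stored under `root`, using the default column names:
#
#     embeddings,target,split
#     embeddings/patch_0001.pt,0,train
#     embeddings/patch_0002.pt,1,train
#     embeddings/patch_0003.pt,1,val

from eva.core.data.datasets.classification import EmbeddingsClassificationDataset

dataset = EmbeddingsClassificationDataset(
    root="/data/patch_camelyon",
    manifest_file="manifest.csv",
    split="train",  # keeps only the rows whose `split` column equals "train"
)
dataset.setup()  # loads and filters the manifest
embedding, target = dataset[0]  # torch.Tensor of shape [embedding_dim], int64 np.ndarray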
eva/core/data/transforms/dtype/array.py
@@ -0,0 +1,28 @@
+"""Transformations to convert numpy arrays to torch tensors."""
+
+import numpy.typing as npt
+import torch
+
+
+class ArrayToTensor:
+    """Converts a numpy array to a torch tensor."""
+
+    def __call__(self, array: npt.ArrayLike) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            array: The input numpy array.
+        """
+        return torch.from_numpy(array)
+
+
+class ArrayToFloatTensor(ArrayToTensor):
+    """Converts a numpy array to a torch tensor and casts it to float."""
+
+    def __call__(self, array: npt.ArrayLike):
+        """Call method for the transformation.
+
+        Args:
+            array: The input numpy array.
+        """
+        return super().__call__(array).float()
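Both transforms are plain callables; a quick sketch of their behavior:

import numpy as np

from eva.core.data.transforms.dtype.array import ArrayToFloatTensor, ArrayToTensor

array = np.array([0, 1, 2], dtype=np.int64)
tensor = ArrayToTensor()(array)  # torch.Tensor with dtype torch.int64
float_tensor = ArrayToFloatTensor()(array)  # torch.Tensor with dtype torch.float32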
eva/core/interface/interface.py
@@ -0,0 +1,79 @@
+"""Main interface class."""
+
+from eva.core import trainers as eva_trainer
+from eva.core.data import datamodules
+from eva.core.models import modules
+
+
+class Interface:
+    """A high-level interface for training and validating a machine learning model.
+
+    This class provides a convenient interface to connect a model, data, and trainer
+    to train and validate a model.
+    """
+
+    def fit(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Perform model training and evaluation out-of-place.
+
+        This method uses the specified trainer to fit the model using the provided data.
+
+        Example use cases:
+
+        - Using a model consisting of a frozen backbone and a head, the backbone will generate
+          the embeddings on the fly which are then used as input features to train the head on
+          the downstream task specified by the given dataset.
+        - Fitting only the head network using a dataset that loads pre-computed embeddings.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module.
+        """
+        trainer.run_evaluation_session(model=model, datamodule=data)
+
+    def predict(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Perform model prediction out-of-place.
+
+        This method performs inference with a pre-trained foundation model to compute embeddings.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module.
+        """
+        eva_trainer.infer_model(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            return_predictions=False,
+        )
+
+    def predict_fit(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Combines the predict and fit commands in one method.
+
+        This method performs the following two steps:
+        1. predict: perform inference with a pre-trained foundation model to compute embeddings.
+        2. fit: train the head network using the embeddings generated in step 1.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module.
+        """
+        self.predict(trainer=trainer, model=model, data=data)
+        self.fit(trainer=trainer, model=model, data=data)
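All three methods share one signature, so the `predict_fit` workflow reduces to composing already-configured objects. A minimal, hedged sketch (the eva CLI presumably builds these objects from a YAML config; here they are simply passed in):

from eva.core import trainers
from eva.core.data import datamodules
from eva.core.interface.interface import Interface
from eva.core.models import modules


def run_evaluation(
    trainer: trainers.Trainer,
    model: modules.ModelModule,
    data: datamodules.DataModule,
) -> None:
    """Computes embeddings with the frozen backbone, then fits the head on them."""
    Interface().predict_fit(trainer=trainer, model=model, data=data)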
eva/core/metrics/__init__.py
@@ -0,0 +1,17 @@
+"""Metrics API."""
+
+from eva.core.metrics.average_loss import AverageLoss
+from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
+from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
+from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
+
+__all__ = [
+    "AverageLoss",
+    "BinaryBalancedAccuracy",
+    "Metric",
+    "MetricCollection",
+    "MetricModule",
+    "MetricsSchema",
+    "MulticlassClassificationMetrics",
+    "BinaryClassificationMetrics",
+]
eva/core/metrics/average_loss.py
@@ -0,0 +1,47 @@
+"""Implementation of the average loss metric."""
+
+import torch
+from loguru import logger
+from typing_extensions import override
+
+from eva.core.metrics import structs
+
+
+class AverageLoss(structs.Metric):
+    """Average loss metric tracker."""
+
+    is_differentiable = True
+    higher_is_better = False
+    full_state_update = False
+
+    def __init__(self) -> None:
+        """Initializes the metric."""
+        super().__init__()
+
+        self.add_state("value", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    @override
+    def update(self, loss: torch.Tensor) -> None:
+        _check_nans(loss)
+        total_samples = loss.numel()
+        if total_samples == 0:
+            return
+
+        self.value = self.value + torch.sum(loss)
+        self.total = self.total + total_samples
+
+    @override
+    def compute(self) -> torch.Tensor:
+        return self.value / self.total
+
+
+def _check_nans(tensor: torch.Tensor) -> None:
+    """Checks for NaN values in the input tensor.
+
+    Logs a warning (rather than raising) if the tensor
+    contains any NaN value(s).
+    """
+    nan_values = tensor.isnan()
+    if nan_values.any():
+        logger.warning("Encountered `nan` value(s) in input tensor.")
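Since `update` counts `loss.numel()` elements per call, the metric averages over samples rather than over batches. A quick sketch:

import torch

from eva.core.metrics import AverageLoss

metric = AverageLoss()
metric.update(torch.tensor([0.5, 1.5]))  # a batch of two per-sample losses
metric.update(torch.tensor(1.0))  # a scalar loss counts as one sample
assert metric.compute() == torch.tensor(1.0)  # (0.5 + 1.5 + 1.0) / 3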
eva/core/metrics/binary_balanced_accuracy.py
@@ -0,0 +1,22 @@
+"""Binary balanced accuracy metric."""
+
+from torch import Tensor
+from torchmetrics.classification import stat_scores
+from torchmetrics.utilities.compute import _safe_divide
+
+
+class BinaryBalancedAccuracy(stat_scores.BinaryStatScores):
+    """Computes the balanced accuracy for binary classification."""
+
+    is_differentiable: bool = False
+    higher_is_better: bool | None = True
+    full_state_update: bool = False
+    plot_lower_bound: float | None = 0.0
+    plot_upper_bound: float | None = 1.0
+
+    def compute(self) -> Tensor:
+        """Compute accuracy based on inputs passed in to ``update`` previously."""
+        tp, fp, tn, fn = self._final_state()
+        sensitivity = _safe_divide(tp, tp + fn)
+        specificity = _safe_divide(tn, tn + fp)
+        return 0.5 * (sensitivity + specificity)
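Balanced accuracy is the mean of sensitivity and specificity, which keeps the score honest on imbalanced binary datasets. A quick sketch:

import torch

from eva.core.metrics import BinaryBalancedAccuracy

metric = BinaryBalancedAccuracy()
preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([1, 0, 0, 0])
metric.update(preds, target)
# tp=1, fp=1, tn=2, fn=0 -> sensitivity = 1.0, specificity = 2/3
print(metric.compute())  # tensor(0.8333)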
eva/core/metrics/defaults/__init__.py
@@ -0,0 +1,6 @@
+"""Default metric collections API."""
+
+from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
+from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
+
+__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
eva/core/metrics/defaults/classification/__init__.py
@@ -0,0 +1,6 @@
+"""Default classification metric collections API."""
+
+from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
+from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
+
+__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]