kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kaiko-eva might be problematic.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/core/data/datamodules/datamodule.py
@@ -0,0 +1,108 @@
+ """Core DataModule."""
+
+ from typing import List
+
+ import lightning.pytorch as pl
+ from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+ from typing_extensions import override
+
+ from eva.core.data import dataloaders as dataloaders_lib
+ from eva.core.data import datasets as datasets_lib
+ from eva.core.data.datamodules import call, schemas
+
+
+ class DataModule(pl.LightningDataModule):
+     """DataModule encapsulates all the steps needed to process data.
+
+     It will initialize and create the mapping between dataloaders and
+     datasets. During `prepare_data`, `setup` and `teardown`, the
+     datamodule will call the respective methods of all datasets,
+     provided they are defined.
+     """
+
+     def __init__(
+         self,
+         datasets: schemas.DatasetsSchema | None = None,
+         dataloaders: schemas.DataloadersSchema | None = None,
+     ) -> None:
+         """Initializes the datamodule.
+
+         Args:
+             datasets: The desired datasets.
+             dataloaders: The desired dataloaders.
+         """
+         super().__init__()
+
+         self.datasets = datasets or self.default_datasets
+         self.dataloaders = dataloaders or self.default_dataloaders
+
+     @property
+     def default_datasets(self) -> schemas.DatasetsSchema:
+         """Returns the default datasets."""
+         return schemas.DatasetsSchema()
+
+     @property
+     def default_dataloaders(self) -> schemas.DataloadersSchema:
+         """Returns the default dataloader schema."""
+         return schemas.DataloadersSchema()
+
+     @override
+     def prepare_data(self) -> None:
+         call.call_method_if_exists(self.datasets.tolist(), "prepare_data")
+
+     @override
+     def setup(self, stage: str) -> None:
+         call.call_method_if_exists(self.datasets.tolist(stage), "setup")
+
+     @override
+     def teardown(self, stage: str) -> None:
+         call.call_method_if_exists(self.datasets.tolist(stage), "teardown")
+
+     @override
+     def train_dataloader(self) -> TRAIN_DATALOADERS:
+         if self.datasets.train is None:
+             raise ValueError(
+                 "Train dataloader cannot be initialized as `self.datasets.train` is `None`."
+             )
+         return self.dataloaders.train(self.datasets.train)
+
+     @override
+     def val_dataloader(self) -> EVAL_DATALOADERS:
+         if self.datasets.val is None:
+             raise ValueError(
+                 "Validation dataloader cannot be initialized as `self.datasets.val` is `None`."
+             )
+         return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val)
+
+     @override
+     def test_dataloader(self) -> EVAL_DATALOADERS:
+         if self.datasets.test is None:
+             raise ValueError(
+                 "Test dataloader cannot be initialized as `self.datasets.test` is `None`."
+             )
+         return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test)
+
+     @override
+     def predict_dataloader(self) -> EVAL_DATALOADERS:
+         if self.datasets.predict is None:
+             raise ValueError(
+                 "Predict dataloader cannot be initialized as `self.datasets.predict` is `None`."
+             )
+         return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict)
+
+     def _initialize_dataloaders(
+         self,
+         dataloader: dataloaders_lib.DataLoader,
+         datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset],
+     ) -> EVAL_DATALOADERS:
+         """Initializes dataloaders from a given dataset or list of datasets.
+
+         Args:
+             dataloader: The dataloader to apply to the provided datasets.
+             datasets: The dataset(s) for which to allocate dataloaders.
+
+         Returns:
+             A list with the dataloaders of the provided dataset(s).
+         """
+         datasets = datasets if isinstance(datasets, list) else [datasets]
+         return list(map(dataloader, datasets))
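
To make the wiring concrete, a minimal usage sketch follows. `RandomDataset` and the `batch_size` argument of eva's `DataLoader` are illustrative assumptions (not taken from the package); only the `DataModule`/schema wiring reflects the code above. Note that `train_dataloader` returns a single dataloader while `val_dataloader` returns a list, since `EVAL_DATASET` may hold several datasets.

```python
# Hypothetical usage sketch; RandomDataset and the batch_size field are
# assumptions for illustration, not part of the package.
import torch
from torch.utils.data import Dataset as TorchDataset

from eva.core.data import dataloaders, datamodules
from eva.core.data.datamodules import schemas


class RandomDataset(TorchDataset):
    """Stand-in dataset emitting random features and a constant label."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int):
        return torch.randn(4), torch.tensor(0)


datamodule = datamodules.DataModule(
    datasets=schemas.DatasetsSchema(train=RandomDataset(), val=RandomDataset()),
    dataloaders=schemas.DataloadersSchema(
        train=dataloaders.DataLoader(batch_size=2),  # assumed field name
        val=dataloaders.DataLoader(batch_size=2),
    ),
)
train_loader = datamodule.train_dataloader()  # a single dataloader
val_loaders = datamodule.val_dataloader()     # a list: one dataloader per val dataset
```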
eva/core/data/datamodules/schemas.py
@@ -0,0 +1,62 @@
+ """Argument schemas used in DataModule."""
+
+ import dataclasses
+ from typing import List
+
+ from eva.core.data import dataloaders, datasets
+
+ TRAIN_DATASET = datasets.TorchDataset | None
+ """Train dataset."""
+
+ EVAL_DATASET = datasets.TorchDataset | List[datasets.TorchDataset] | None
+ """Evaluation dataset."""
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DatasetsSchema:
+     """Datasets schema used in DataModule."""
+
+     train: TRAIN_DATASET = None
+     """Train dataset."""
+
+     val: EVAL_DATASET = None
+     """Validation dataset."""
+
+     test: EVAL_DATASET = None
+     """Test dataset."""
+
+     predict: EVAL_DATASET = None
+     """Predict dataset."""
+
+     def tolist(self, stage: str | None = None) -> List[EVAL_DATASET]:
+         """Returns the dataclass as a list and optionally filters it given the stage."""
+         match stage:
+             case "fit":
+                 return [self.train, self.val]
+             case "validate":
+                 return [self.val]
+             case "test":
+                 return [self.test]
+             case "predict":
+                 return [self.predict]
+             case None:
+                 return [self.train, self.val, self.test, self.predict]
+             case _:
+                 raise ValueError(f"Invalid stage `{stage}`.")
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DataloadersSchema:
+     """Dataloaders schema used in DataModule."""
+
+     train: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+     """Train dataloader."""
+
+     val: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+     """Validation dataloader."""
+
+     test: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+     """Test dataloader."""
+
+     predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
+     """Predict dataloader."""
eva/core/data/datasets/__init__.py
@@ -0,0 +1,7 @@
+ """Datasets API."""
+
+ from eva.core.data.datasets.base import Dataset
+ from eva.core.data.datasets.classification import EmbeddingsClassificationDataset
+ from eva.core.data.datasets.dataset import TorchDataset
+
+ __all__ = ["Dataset", "EmbeddingsClassificationDataset", "TorchDataset"]
eva/core/data/datasets/base.py
@@ -0,0 +1,53 @@
+ """Base dataset class."""
+
+ from eva.core.data.datasets import dataset
+
+
+ class Dataset(dataset.TorchDataset):
+     """Base dataset class."""
+
+     def prepare_data(self) -> None:
+         """Encapsulates all disk-related tasks.
+
+         This method is preferred for downloading and preparing the data, for
+         example generating manifest files. If implemented, it will be called via
+         :class:`eva.core.data.datamodules.DataModule`, which ensures that it is
+         called only within a single process, making it multiprocessing-safe.
+         """
+
+     def setup(self) -> None:
+         """Sets up the dataset.
+
+         This method is preferred for creating datasets or performing
+         train/val/test splits. If implemented, it will be called via
+         :class:`eva.core.data.datamodules.DataModule` at the beginning of fit
+         (train + validate), validate, test, or predict and it will be called
+         from every process (i.e. GPU) across all the nodes in DDP.
+         """
+         self.configure()
+         self.validate()
+
+     def configure(self) -> None:
+         """Configures the dataset.
+
+         This method is preferred for configuring the dataset: assigning values
+         to attributes, performing splits etc. It is called from :meth:`setup`,
+         before :meth:`validate`.
+         """
+
+     def validate(self) -> None:
+         """Validates the dataset.
+
+         This method aims to check the integrity of the dataset and verify
+         that it is configured properly. It is called from :meth:`setup`,
+         after :meth:`configure`.
+         """
+
+     def teardown(self) -> None:
+         """Cleans up the data artifacts.
+
+         Used to clean up when the run is finished. If implemented, it will
+         be called via :class:`eva.core.data.datamodules.DataModule` at the end
+         of fit (train + validate), validate, test, or predict and it will be
+         called from every process (i.e. GPU) across all the nodes in DDP.
+         """
eva/core/data/datasets/classification/__init__.py
@@ -0,0 +1,5 @@
+ """Classification datasets API."""
+
+ from eva.core.data.datasets.classification.embeddings import EmbeddingsClassificationDataset
+
+ __all__ = ["EmbeddingsClassificationDataset"]
eva/core/data/datasets/classification/embeddings.py
@@ -0,0 +1,154 @@
+ """Embeddings classification dataset."""
+
+ import os
+ from typing import Callable, Dict, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing_extensions import override
+
+ from eva.core.data.datasets import base
+ from eva.core.utils import io
+
+
+ class EmbeddingsClassificationDataset(base.Dataset):
+     """Embeddings classification dataset."""
+
+     default_column_mapping: Dict[str, str] = {
+         "data": "embeddings",
+         "target": "target",
+         "split": "split",
+     }
+     """The default column mapping of the variables to the manifest columns."""
+
+     def __init__(
+         self,
+         root: str,
+         manifest_file: str,
+         split: str | None = None,
+         column_mapping: Dict[str, str] = default_column_mapping,
+         embeddings_transforms: Callable | None = None,
+         target_transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the dataset.
+
+         Expects a manifest file listing the paths of .pt files that contain
+         tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+         Args:
+             root: Root directory of the dataset.
+             manifest_file: The path to the manifest file, relative to the
+                 `root` argument.
+             split: The dataset split to use. The manifest is filtered on its
+                 `split` column based on this value.
+             column_mapping: Defines the map between the variables and the
+                 manifest columns. It will overwrite the `default_column_mapping`
+                 with the provided values, so `column_mapping` may contain only
+                 the values which are altered or missing.
+             embeddings_transforms: A function/transform that transforms the embedding.
+             target_transforms: A function/transform that transforms the target.
+         """
+         super().__init__()
+
+         self._root = root
+         self._manifest_file = manifest_file
+         self._split = split
+         self._column_mapping = self.default_column_mapping | column_mapping
+         self._embeddings_transforms = embeddings_transforms
+         self._target_transforms = target_transforms
+
+         self._data: pd.DataFrame
+
+     def filename(self, index: int) -> str:
+         """Returns the filename of the `index`'th data sample.
+
+         Note that this is the file path relative to the root.
+
+         Args:
+             index: The index of the data sample to select.
+
+         Returns:
+             The filename of the `index`'th data sample.
+         """
+         return self._data.at[index, self._column_mapping["data"]]
+
+     @override
+     def setup(self):
+         self._data = self._load_manifest()
+
+     def __getitem__(self, index: int) -> Tuple[torch.Tensor, np.ndarray]:
+         """Returns the `index`'th data sample.
+
+         Args:
+             index: The index of the data sample to select.
+
+         Returns:
+             A data sample and its target.
+         """
+         embeddings = self._load_embeddings(index)
+         target = self._load_target(index)
+         return self._apply_transforms(embeddings, target)
+
+     def __len__(self) -> int:
+         """Returns the total length of the data."""
+         return len(self._data)
+
+     def _load_embeddings(self, index: int) -> torch.Tensor:
+         """Returns the `index`'th embedding sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The sample embedding as a tensor.
+         """
+         filename = self.filename(index)
+         embeddings_path = os.path.join(self._root, filename)
+         tensor = torch.load(embeddings_path, map_location="cpu")
+         return tensor.squeeze(0)
+
+     def _load_target(self, index: int) -> np.ndarray:
+         """Returns the `index`'th target sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The sample target as an array.
+         """
+         target = self._data.at[index, self._column_mapping["target"]]
+         return np.asarray(target, dtype=np.int64)
+
+     def _load_manifest(self) -> pd.DataFrame:
+         """Loads the manifest file and filters the data based on the split column.
+
+         Returns:
+             The data as a pandas DataFrame.
+         """
+         manifest_path = os.path.join(self._root, self._manifest_file)
+         data = io.read_dataframe(manifest_path)
+         if self._split is not None:
+             filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
+             data = filtered_data.reset_index(drop=True)
+         return data
+
+     def _apply_transforms(
+         self, embeddings: torch.Tensor, target: np.ndarray
+     ) -> Tuple[torch.Tensor, np.ndarray]:
+         """Applies the transforms to the provided data and returns them.
+
+         Args:
+             embeddings: The embeddings to be transformed.
+             target: The training target.
+
+         Returns:
+             A tuple with the embeddings and the target transformed.
+         """
+         if self._embeddings_transforms is not None:
+             embeddings = self._embeddings_transforms(embeddings)
+
+         if self._target_transforms is not None:
+             target = self._target_transforms(target)
+
+         return embeddings, target
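
The manifest format is easiest to see by example. Below is a hypothetical CSV matching `default_column_mapping` (all paths, targets and splits invented for illustration), followed by the corresponding dataset construction:

```python
# Hypothetical manifest.csv matching default_column_mapping:
#
#   embeddings,target,split
#   embeddings/patch_0001.pt,0,train
#   embeddings/patch_0002.pt,1,train
#   embeddings/patch_0003.pt,1,val
#
from eva.core.data.datasets.classification import EmbeddingsClassificationDataset

dataset = EmbeddingsClassificationDataset(
    root="/data/my_task",         # assumed local path
    manifest_file="manifest.csv",
    split="train",                # keeps only rows whose split column is "train"
)
dataset.setup()                   # loads and filters the manifest
embedding, target = dataset[0]    # (torch.Tensor, np.ndarray)
```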
eva/core/data/datasets/dataset.py
@@ -0,0 +1,6 @@
+ """Base torch dataset module."""
+
+ from torch.utils import data
+
+ TorchDataset = data.Dataset
+ """Base torch dataset class."""
eva/core/data/samplers/__init__.py
@@ -0,0 +1,5 @@
+ """Data samplers API."""
+
+ from eva.core.data.samplers.sampler import Sampler
+
+ __all__ = ["Sampler"]
eva/core/data/samplers/sampler.py
@@ -0,0 +1,6 @@
+ """Core data sampler."""
+
+ from torch.utils import data
+
+ Sampler = data.Sampler
+ """Core abstract data sampler class."""
eva/core/data/transforms/__init__.py
@@ -0,0 +1,5 @@
+ """Core data transforms."""
+
+ from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
+
+ __all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
eva/core/data/transforms/dtype/__init__.py
@@ -0,0 +1,5 @@
+ """Type casting related transforms."""
+
+ from eva.core.data.transforms.dtype.array import ArrayToFloatTensor, ArrayToTensor
+
+ __all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
eva/core/data/transforms/dtype/array.py
@@ -0,0 +1,28 @@
+ """Transformations to convert numpy arrays to torch tensors."""
+
+ import numpy.typing as npt
+ import torch
+
+
+ class ArrayToTensor:
+     """Converts a numpy array to a torch tensor."""
+
+     def __call__(self, array: npt.ArrayLike) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             array: The input numpy array.
+         """
+         return torch.from_numpy(array)
+
+
+ class ArrayToFloatTensor(ArrayToTensor):
+     """Converts a numpy array to a torch tensor and casts it to float."""
+
+     def __call__(self, array: npt.ArrayLike) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             array: The input numpy array.
+         """
+         return super().__call__(array).float()
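
A brief usage sketch of the two transforms, using only the API shown above:

```python
# Usage sketch for the dtype transforms.
import numpy as np

from eva.core.data.transforms import ArrayToFloatTensor, ArrayToTensor

array = np.array([[1, 2], [3, 4]], dtype=np.int64)
print(ArrayToTensor()(array).dtype)       # torch.int64
print(ArrayToFloatTensor()(array).dtype)  # torch.float32
```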
eva/core/interface/__init__.py
@@ -0,0 +1,5 @@
+ """Interface API."""
+
+ from eva.core.interface.interface import Interface
+
+ __all__ = ["Interface"]
eva/core/interface/interface.py
@@ -0,0 +1,79 @@
+ """Main interface class."""
+
+ from eva.core import trainers as eva_trainer
+ from eva.core.data import datamodules
+ from eva.core.models import modules
+
+
+ class Interface:
+     """A high-level interface for training and validating a machine learning model.
+
+     This class provides a convenient interface to connect a model, data, and trainer
+     to train and validate a model.
+     """
+
+     def fit(
+         self,
+         trainer: eva_trainer.Trainer,
+         model: modules.ModelModule,
+         data: datamodules.DataModule,
+     ) -> None:
+         """Performs model training and evaluation out-of-place.
+
+         This method uses the specified trainer to fit the model using the provided data.
+
+         Example use cases:
+
+         - With a model consisting of a frozen backbone and a head, the backbone
+           generates the embeddings on the fly, which are then used as input
+           features to train the head on the downstream task specified by the
+           given dataset.
+         - Fitting only the head network using a dataset that loads pre-computed
+           embeddings.
+
+         Args:
+             trainer: The base trainer to use but not modify.
+             model: The model module to use but not modify.
+             data: The data module.
+         """
+         trainer.run_evaluation_session(model=model, datamodule=data)
+
+     def predict(
+         self,
+         trainer: eva_trainer.Trainer,
+         model: modules.ModelModule,
+         data: datamodules.DataModule,
+     ) -> None:
+         """Performs model prediction out-of-place.
+
+         This method performs inference with a pre-trained foundation model to compute embeddings.
+
+         Args:
+             trainer: The base trainer to use but not modify.
+             model: The model module to use but not modify.
+             data: The data module.
+         """
+         eva_trainer.infer_model(
+             base_trainer=trainer,
+             base_model=model,
+             datamodule=data,
+             return_predictions=False,
+         )
+
+     def predict_fit(
+         self,
+         trainer: eva_trainer.Trainer,
+         model: modules.ModelModule,
+         data: datamodules.DataModule,
+     ) -> None:
+         """Combines the predict and fit commands in one method.
+
+         This method performs the following two steps:
+         1. predict: perform inference with a pre-trained foundation model to
+            compute embeddings.
+         2. fit: train the head network using the embeddings generated in step 1.
+
+         Args:
+             trainer: The base trainer to use but not modify.
+             model: The model module to use but not modify.
+             data: The data module.
+         """
+         self.predict(trainer=trainer, model=model, data=data)
+         self.fit(trainer=trainer, model=model, data=data)
eva/core/metrics/__init__.py
@@ -0,0 +1,17 @@
+ """Metrics API."""
+
+ from eva.core.metrics.average_loss import AverageLoss
+ from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
+ from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
+ from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
+
+ __all__ = [
+     "AverageLoss",
+     "BinaryBalancedAccuracy",
+     "Metric",
+     "MetricCollection",
+     "MetricModule",
+     "MetricsSchema",
+     "MulticlassClassificationMetrics",
+     "BinaryClassificationMetrics",
+ ]
eva/core/metrics/average_loss.py
@@ -0,0 +1,47 @@
+ """Implementation of the average loss metric."""
+
+ import torch
+ from loguru import logger
+ from typing_extensions import override
+
+ from eva.core.metrics import structs
+
+
+ class AverageLoss(structs.Metric):
+     """Average loss metric tracker."""
+
+     is_differentiable = True
+     higher_is_better = False
+     full_state_update = False
+
+     def __init__(self) -> None:
+         """Initializes the metric."""
+         super().__init__()
+
+         self.add_state("value", default=torch.tensor(0), dist_reduce_fx="sum")
+         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+     @override
+     def update(self, loss: torch.Tensor) -> None:
+         _check_nans(loss)
+         total_samples = loss.numel()
+         if total_samples == 0:
+             return
+
+         self.value = self.value + torch.sum(loss)
+         self.total = self.total + total_samples
+
+     @override
+     def compute(self) -> torch.Tensor:
+         return self.value / self.total
+
+
+ def _check_nans(tensor: torch.Tensor) -> None:
+     """Checks for NaN values and logs a warning if any are found."""
+     nan_values = tensor.isnan()
+     if nan_values.any():
+         logger.warning("Encountered `nan` value(s) in input tensor.")
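
A short sketch of how the tracker averages per-sample losses across batches of different sizes, using only the `update`/`compute` API shown above:

```python
# Usage sketch: AverageLoss accumulates per-sample losses across batches.
import torch

from eva.core.metrics import AverageLoss

metric = AverageLoss()
metric.update(torch.tensor([0.5, 1.5]))  # batch of two per-sample losses
metric.update(torch.tensor([1.0]))       # batch of one
print(metric.compute())                  # tensor(1.) == (0.5 + 1.5 + 1.0) / 3
```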
eva/core/metrics/binary_balanced_accuracy.py
@@ -0,0 +1,22 @@
+ """Binary balanced accuracy metric."""
+
+ from torch import Tensor
+ from torchmetrics.classification import stat_scores
+ from torchmetrics.utilities.compute import _safe_divide
+
+
+ class BinaryBalancedAccuracy(stat_scores.BinaryStatScores):
+     """Computes the balanced accuracy for binary classification."""
+
+     is_differentiable: bool = False
+     higher_is_better: bool | None = True
+     full_state_update: bool = False
+     plot_lower_bound: float | None = 0.0
+     plot_upper_bound: float | None = 1.0
+
+     def compute(self) -> Tensor:
+         """Compute accuracy based on inputs passed in to ``update`` previously."""
+         tp, fp, tn, fn = self._final_state()
+         sensitivity = _safe_divide(tp, tp + fn)
+         specificity = _safe_divide(tn, tn + fp)
+         return 0.5 * (sensitivity + specificity)
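
A worked example of the formula above: with the single positive classified correctly and one of five negatives misclassified, sensitivity is 1.0 and specificity 0.8, giving a balanced accuracy of 0.9.

```python
# Usage sketch: balanced accuracy = mean of sensitivity and specificity.
import torch

from eva.core.metrics import BinaryBalancedAccuracy

metric = BinaryBalancedAccuracy()
preds = torch.tensor([1, 1, 0, 0, 0, 0])
target = torch.tensor([1, 0, 0, 0, 0, 0])
metric.update(preds, target)
# tp=1, fn=0 -> sensitivity = 1.0; tn=4, fp=1 -> specificity = 0.8
print(metric.compute())  # tensor(0.9000)
```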
eva/core/metrics/defaults/__init__.py
@@ -0,0 +1,6 @@
+ """Default metric collections API."""
+
+ from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
+ from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
+
+ __all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]
eva/core/metrics/defaults/classification/__init__.py
@@ -0,0 +1,6 @@
+ """Default classification metric collections API."""
+
+ from eva.core.metrics.defaults.classification.binary import BinaryClassificationMetrics
+ from eva.core.metrics.defaults.classification.multiclass import MulticlassClassificationMetrics
+
+ __all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics"]