kaiko-eva 0.0.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/core/models/modules/module.py
@@ -0,0 +1,190 @@
+"""Base model module."""
+
+from typing import Any, Mapping
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch.utilities import memory
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules.typings import INPUT_BATCH
+from eva.core.models.modules.utils import batch_postprocess
+
+
+class ModelModule(pl.LightningModule):
+    """The base model module."""
+
+    def __init__(
+        self,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the basic module.
+
+        Args:
+            metrics: The metric groups to track.
+            postprocess: A list of helper functions to apply after the
+                loss and before the metrics calculation to the model
+                predictions and targets.
+        """
+        super().__init__()
+
+        self._metrics = metrics or self.default_metrics
+        self._postprocess = postprocess or self.default_postprocess
+
+        self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)
+
+    @property
+    def default_metrics(self) -> metrics_lib.MetricsSchema:
+        """The default metrics."""
+        return metrics_lib.MetricsSchema()
+
+    @property
+    def default_postprocess(self) -> batch_postprocess.BatchPostProcess:
+        """The default post-processes."""
+        return batch_postprocess.BatchPostProcess()
+
+    @override
+    def on_train_batch_end(
+        self,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_BATCH,
+        batch_idx: int,
+    ) -> None:
+        outputs = self._common_batch_end(outputs)
+        self._forward_and_log_metrics(
+            self.metrics.training_metrics,
+            batch_outputs=outputs,
+        )
+
+    @override
+    def on_validation_batch_end(
+        self,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_BATCH,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        outputs = self._common_batch_end(outputs)
+        self._update_metrics(
+            self.metrics.validation_metrics,
+            outputs=outputs,
+            dataloader_idx=dataloader_idx,
+        )
+
+    @override
+    def on_validation_epoch_end(self) -> None:
+        self._compute_and_log_metrics(self.metrics.validation_metrics)
+
+    @override
+    def on_test_batch_end(
+        self,
+        outputs: STEP_OUTPUT,
+        batch: INPUT_BATCH,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        outputs = self._common_batch_end(outputs)
+        self._update_metrics(
+            self.metrics.test_metrics,
+            outputs=outputs,
+            dataloader_idx=dataloader_idx,
+        )
+
+    @override
+    def on_test_epoch_end(self) -> None:
+        self._compute_and_log_metrics(self.metrics.test_metrics)
+
+    def _common_batch_end(self, outputs: STEP_OUTPUT) -> STEP_OUTPUT:
+        """Common end step of training, validation and test.
+
+        It applies the post-processes to the batch outputs and moves
+        them to the appropriate device.
+
+        Args:
+            outputs: The batch step outputs.
+
+        Returns:
+            The updated outputs.
+        """
+        self._postprocess(outputs)
+        return memory.recursive_detach(outputs, to_cpu=self.device.type == "cpu")
+
+    def _forward_and_log_metrics(
+        self,
+        metrics: metrics_lib.MetricCollection,
+        batch_outputs: STEP_OUTPUT,
+    ) -> None:
+        """Performs a forward pass through the metrics and logs them.
+
+        Args:
+            metrics: The desired metrics tracker to log.
+            batch_outputs: The outputs of the batch processing step.
+        """
+        inputs = self._parse_metrics_inputs(batch_outputs)
+        metrics(**inputs)
+        self.log_dict(metrics, on_step=True, on_epoch=False)
+
+    def _update_metrics(
+        self,
+        metrics: metrics_lib.MetricCollection,
+        outputs: STEP_OUTPUT,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Updates the metrics tracker with new data.
+
+        Here the `outputs` keyword values will be filtered based
+        on the signature of all individual metrics and passed only
+        to the compatible ones.
+
+        Args:
+            metrics: The desired metrics tracker to update.
+            outputs: The outputs of the batch processing step.
+            dataloader_idx: The dataloader index.
+        """
+        inputs = self._parse_metrics_inputs(outputs, dataloader_idx)
+        metrics.update(**inputs)
+
+    def _compute_and_log_metrics(self, metrics: metrics_lib.MetricCollection) -> None:
+        """Computes, logs and resets the metrics.
+
+        Args:
+            metrics: The desired metrics tracker to log.
+        """
+        outputs = metrics.compute()
+        self.log_dict(outputs)
+        metrics.reset()
+
+    def _parse_metrics_inputs(
+        self,
+        outputs: STEP_OUTPUT,
+        dataloader_idx: int = 0,
+    ) -> Mapping[str, Any]:
+        """Parses the arguments for the metrics.
+
+        When passed to a metrics collection object, the keyword values
+        will be filtered based on the signature of all individual
+        metrics and passed only to the compatible ones.
+
+        Args:
+            outputs: The outputs of the batch processing step.
+            dataloader_idx: The dataloader index.
+
+        Returns:
+            A mapping with the argument name and its value.
+        """
+        if outputs is None:
+            return {}
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = {"loss": outputs}
+
+        additional_metric_inputs = {
+            "preds": outputs.get("predictions"),
+            "target": outputs.get("targets"),
+            "metadata": outputs.get("metadata"),
+            "dataloader_idx": dataloader_idx,
+        }
+        return {**additional_metric_inputs, **outputs}
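A minimal usage sketch, not part of the package: a hypothetical subclass implements the Lightning step methods and returns the dict-shaped STEP_OUTPUT that `_parse_metrics_inputs` above expects, so that "predictions" and "targets" flow into the metrics.

import torch
from torch import nn

from eva.core.models.modules.module import ModelModule


class LinearClassifier(ModelModule):  # hypothetical subclass for illustration
    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.head = nn.Linear(in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx: int):
        data, targets, _ = batch  # INPUT_BATCH unpacks positionally
        predictions = self.head(data)
        loss = self.criterion(predictions, targets)
        # The dict keys below are the ones _parse_metrics_inputs looks up.
        return {"loss": loss, "predictions": predictions, "targets": targets}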
eva/core/models/modules/typings.py
@@ -0,0 +1,23 @@
+"""Type annotations for model modules."""
+
+from typing import Any, Dict, NamedTuple
+
+import lightning.pytorch as pl
+import torch
+from torch import nn
+
+MODEL_TYPE = nn.Module | pl.LightningModule
+"""The expected model type."""
+
+
+class INPUT_BATCH(NamedTuple):
+    """The default input batch data scheme."""
+
+    data: torch.Tensor
+    """The data batch."""
+
+    targets: torch.Tensor | Dict[str, Any] | None = None
+    """The target batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
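For illustration, a batch under this scheme might look as follows (shapes and metadata keys are arbitrary examples):

import torch

from eva.core.models.modules.typings import INPUT_BATCH

batch = INPUT_BATCH(
    data=torch.randn(4, 3, 224, 224),             # e.g. an image batch
    targets=torch.tensor([0, 1, 1, 0]),           # e.g. class labels
    metadata={"slide_id": ["a", "b", "c", "d"]},  # optional extras
)
data, targets, metadata = batch  # NamedTuple unpacks positionally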
eva/core/models/modules/utils/batch_postprocess.py
@@ -0,0 +1,57 @@
+"""Batch post-processes module."""
+
+import dataclasses
+import functools
+from typing import Callable, List
+
+import torch
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+Transform = Callable[[torch.Tensor], torch.Tensor]
+"""Post-process transform type."""
+
+
+@dataclasses.dataclass(frozen=True)
+class BatchPostProcess:
+    """Batch post-processes transform schema."""
+
+    targets_transforms: List[Transform] | None = None
+    """Transforms to apply to the batch targets."""
+
+    predictions_transforms: List[Transform] | None = None
+    """Transforms to apply to the batch predictions."""
+
+    def __call__(self, outputs: STEP_OUTPUT) -> None:
+        """Applies the defined list of transforms to the batch output in-place.
+
+        Note that the transforms are applied only when the input is a dictionary
+        and only to its `predictions` and/or `targets` keys.
+
+        Args:
+            outputs: The batch output of the model module step.
+        """
+        if not isinstance(outputs, dict):
+            return
+
+        if "targets" in outputs and self.targets_transforms is not None:
+            outputs["targets"] = _apply_transforms(
+                outputs["targets"], transforms=self.targets_transforms
+            )
+
+        if "predictions" in outputs and self.predictions_transforms is not None:
+            outputs["predictions"] = _apply_transforms(
+                outputs["predictions"], transforms=self.predictions_transforms
+            )
+
+
+def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torch.Tensor:
+    """Applies a list of transforms sequentially to an input tensor.
+
+    Args:
+        tensor: The desired tensor to process.
+        transforms: The list of transforms to apply to the input tensor.
+
+    Returns:
+        The processed tensor.
+    """
+    return functools.reduce(lambda tensor, transform: transform(tensor), transforms, tensor)
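A small sketch of the intended call pattern (values are made up): the transforms mutate the step-output dict in place, e.g. mapping logits to probabilities before the metrics run.

import torch

from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess

postprocess = BatchPostProcess(
    predictions_transforms=[torch.sigmoid],              # logits -> probabilities
    targets_transforms=[lambda t: t.to(torch.float32)],  # cast labels
)
outputs = {
    "loss": torch.tensor(0.7),
    "predictions": torch.tensor([2.0, -1.0]),
    "targets": torch.tensor([1, 0]),
}
postprocess(outputs)  # edits "predictions" and "targets" in-place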
eva/core/models/modules/utils/grad.py
@@ -0,0 +1,23 @@
+"""Gradient related utilities and helper functions."""
+
+from torch import nn
+
+
+def deactivate_requires_grad(module: nn.Module) -> None:
+    """Deactivates the `requires_grad` flag for all parameters of a model.
+
+    Args:
+        module: The torch module to deactivate the gradient computation of in-place.
+    """
+    for parameter in module.parameters():
+        parameter.requires_grad = False
+
+
+def activate_requires_grad(module: nn.Module) -> None:
+    """Activates the `requires_grad` flag for all parameters of a model.
+
+    Args:
+        module: The torch module to activate the gradient computation of in-place.
+    """
+    for parameter in module.parameters():
+        parameter.requires_grad = True
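Typical use, sketched: freeze a backbone while leaving a head trainable (the modules here are placeholders).

from torch import nn

from eva.core.models.modules.utils.grad import deactivate_requires_grad

backbone, head = nn.Linear(8, 8), nn.Linear(8, 2)  # placeholder modules
deactivate_requires_grad(backbone)
assert not any(p.requires_grad for p in backbone.parameters())
assert all(p.requires_grad for p in head.parameters())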
eva/core/models/networks/_utils.py
@@ -0,0 +1,25 @@
+"""Utilities and helper functions for models."""
+
+from lightning_fabric.utilities import cloud_io
+from loguru import logger
+from torch import nn
+
+
+def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
+    """Loads (local or remote) weights to the model in-place.
+
+    Args:
+        model: The model to load the weights to.
+        checkpoint_path: The path to the model weights/checkpoint.
+    """
+    logger.info(f"Loading '{model.__class__.__name__}' model from checkpoint '{checkpoint_path}'")
+
+    fs = cloud_io.get_filesystem(checkpoint_path)
+    with fs.open(checkpoint_path, "rb") as file:
+        checkpoint = cloud_io._load(file, map_location="cpu")  # type: ignore
+        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+    model.load_state_dict(checkpoint, strict=True)
+
+    logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
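Sketch of a call; the checkpoint path below is a placeholder. Since loading goes through `cloud_io.get_filesystem`, remote URLs (e.g. `s3://...`) should also work when the matching filesystem plugin is installed.

import torch.nn as nn

from eva.core.models.networks import _utils

model = nn.Linear(10, 2)
_utils.load_model_weights(model, "checkpoints/linear.ckpt")  # placeholder path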
eva/core/models/networks/mlp.py
@@ -0,0 +1,69 @@
+"""Multi-layer Perceptron (MLP) implemented in PyTorch."""
+
+from typing import Type
+
+import torch
+import torch.nn as nn
+
+
+class MLP(nn.Module):
+    """A Multi-layer Perceptron (MLP) network."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        hidden_layer_sizes: tuple[int, ...] | None = None,
+        hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
+        output_activation_fn: Type[torch.nn.Module] | None = None,
+        dropout: float = 0.0,
+    ) -> None:
+        """Initializes the MLP.
+
+        Args:
+            input_size: The number of input features.
+            output_size: The number of output features.
+            hidden_layer_sizes: A sequence specifying the number of units in each hidden layer.
+            dropout: Dropout probability for hidden layers.
+            hidden_activation_fn: Activation function to use for hidden layers. Default is ReLU.
+            output_activation_fn: Activation function to use for the output layer. Default is None.
+        """
+        super().__init__()
+
+        self.input_size = input_size
+        self.output_size = output_size
+        self.hidden_layer_sizes = hidden_layer_sizes if hidden_layer_sizes is not None else ()
+        self.hidden_activation_fn = hidden_activation_fn
+        self.output_activation_fn = output_activation_fn
+        self.dropout = dropout
+
+        self._network = self._build_network()
+
+    def _build_network(self) -> nn.Sequential:
+        """Builds the neural network's layers and returns a nn.Sequential container."""
+        layers = []
+        prev_size = self.input_size
+        for size in self.hidden_layer_sizes:
+            layers.append(nn.Linear(prev_size, size))
+            if self.hidden_activation_fn is not None:
+                layers.append(self.hidden_activation_fn())
+            if self.dropout > 0:
+                layers.append(nn.Dropout(self.dropout))
+            prev_size = size
+
+        layers.append(nn.Linear(prev_size, self.output_size))
+        if self.output_activation_fn is not None:
+            layers.append(self.output_activation_fn())
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Defines the forward pass of the MLP.
+
+        Args:
+            x: The input tensor.
+
+        Returns:
+            The output of the network.
+        """
+        return self._network(x)
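Usage example (sizes chosen arbitrarily):

import torch

from eva.core.models.networks.mlp import MLP

network = MLP(
    input_size=384,
    output_size=4,
    hidden_layer_sizes=(256, 128),
    dropout=0.1,
)
logits = network(torch.randn(32, 384))  # -> shape (32, 4)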
eva/core/models/networks/transforms/extract_cls_features.py
@@ -0,0 +1,25 @@
+"""Transforms for extracting the CLS output from a model output."""
+
+import torch
+from transformers import modeling_outputs
+
+
+class ExtractCLSFeatures:
+    """Extracts the CLS token from a ViT model output."""
+
+    def __call__(
+        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
+    ) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The tensor representing the model output.
+        """
+        if isinstance(tensor, torch.Tensor):
+            transformed_tensor = tensor[:, 0, :]
+        elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
+            transformed_tensor = tensor.last_hidden_state[:, 0, :]
+        else:
+            raise ValueError(f"Unsupported type {type(tensor)}")
+
+        return transformed_tensor
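For example, applied to a ViT-style token sequence (shape values are illustrative):

import torch

from eva.core.models.networks.transforms.extract_cls_features import ExtractCLSFeatures

extract_cls = ExtractCLSFeatures()
tokens = torch.randn(8, 197, 768)    # [CLS] + 196 patch tokens, ViT-B-like
cls_embedding = extract_cls(tokens)  # -> shape (8, 768)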
eva/core/models/networks/wrappers/__init__.py
@@ -0,0 +1,8 @@
+"""Model Wrappers API."""
+
+from eva.core.models.networks.wrappers.base import BaseModel
+from eva.core.models.networks.wrappers.from_function import ModelFromFunction
+from eva.core.models.networks.wrappers.huggingface import HuggingFaceModel
+from eva.core.models.networks.wrappers.onnx import ONNXModel
+
+__all__ = ["BaseModel", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"]
eva/core/models/networks/wrappers/base.py
@@ -0,0 +1,47 @@
+"""Base class for model wrappers."""
+
+import abc
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+
+class BaseModel(nn.Module):
+    """Base class for model wrappers."""
+
+    def __init__(self, tensor_transforms: Callable | None = None) -> None:
+        """Initializes the model.
+
+        Args:
+            tensor_transforms: The transforms to apply to the output
+                tensor produced by the model.
+        """
+        super().__init__()
+
+        self._output_transforms = tensor_transforms
+
+    @override
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        tensor = self.model_forward(tensor)
+        return self._apply_transforms(tensor)
+
+    @abc.abstractmethod
+    def load_model(self) -> Callable[..., torch.Tensor]:
+        """Loads the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Implements the forward pass of the model.
+
+        Args:
+            tensor: The input tensor to the model.
+        """
+        raise NotImplementedError
+
+    def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self._output_transforms is not None:
+            tensor = self._output_transforms(tensor)
+        return tensor
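A minimal, hypothetical subclass showing the contract: implement `load_model` and `model_forward`, and the inherited `forward` applies the optional output transforms.

import torch
from typing_extensions import override

from eva.core.models.networks.wrappers import base


class IdentityModel(base.BaseModel):  # hypothetical example wrapper
    def __init__(self) -> None:
        super().__init__(tensor_transforms=None)
        self._model = self.load_model()

    @override
    def load_model(self) -> torch.nn.Module:
        return torch.nn.Identity()

    @override
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self._model(tensor)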
eva/core/models/networks/wrappers/from_function.py
@@ -0,0 +1,58 @@
+"""Model wrapper for models defined with a function."""
+
+from typing import Any, Callable, Dict
+
+import jsonargparse
+import torch
+from torch import nn
+from typing_extensions import override
+
+from eva.core.models.networks import _utils
+from eva.core.models.networks.wrappers import base
+
+
+class ModelFromFunction(base.BaseModel):
+    """Wrapper class for models which are initialized from functions.
+
+    This is helpful for initializing models in a `.yaml` configuration file.
+    """
+
+    def __init__(
+        self,
+        path: Callable[..., nn.Module],
+        arguments: Dict[str, Any] | None = None,
+        checkpoint_path: str | None = None,
+        tensor_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes and constructs the model.
+
+        Args:
+            path: The path to the callable object (class or function).
+            arguments: The extra callable function / class arguments.
+            checkpoint_path: The path to the checkpoint to load the model
+                weights from. This is currently only supported for torch
+                model checkpoints. For other formats, the checkpoint loading
+                should be handled within the provided callable object in <path>.
+            tensor_transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(tensor_transforms=tensor_transforms)
+
+        self._path = path
+        self._arguments = arguments
+        self._checkpoint_path = checkpoint_path
+        self._tensor_transforms = tensor_transforms
+
+        self._model = self.load_model()
+
+    @override
+    def load_model(self) -> nn.Module:
+        class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
+        model = class_path(**self._arguments or {})
+        if self._checkpoint_path is not None:
+            _utils.load_model_weights(model, self._checkpoint_path)
+        return model
+
+    @override
+    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return self._model(tensor)
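Sketch of usage with an external model factory; `timm` is an assumption here, not a dependency declared in this file, and any callable returning an `nn.Module` works the same way.

import timm.models  # assumed external dependency, for illustration only

from eva.core.models.networks.wrappers import ModelFromFunction

backbone = ModelFromFunction(
    path=timm.models.vit_small_patch16_224,
    arguments={"pretrained": False, "num_classes": 0},
)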
eva/core/models/networks/wrappers/huggingface.py
@@ -0,0 +1,37 @@
+"""Wrappers for HuggingFace `transformers` models."""
+
+from typing import Any, Callable
+
+import torch
+import transformers
+from typing_extensions import override
+
+from eva.core.models.networks.wrappers import base
+
+
+class HuggingFaceModel(base.BaseModel):
+    """Wrapper class for loading HuggingFace `transformers` models."""
+
+    def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name_or_path: The model name or path to load the model from.
+                This can be a local path or a model name from the `HuggingFace`
+                model hub.
+            tensor_transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(tensor_transforms=tensor_transforms)
+
+        self._model_name_or_path = model_name_or_path
+        self._model = self.load_model()
+
+    @override
+    def load_model(self) -> Any:
+        config = transformers.AutoConfig.from_pretrained(self._model_name_or_path)
+        return transformers.AutoModel.from_pretrained(self._model_name_or_path, config=config)
+
+    @override
+    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return self._model(tensor)
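Usage sketch; the hub id below is an example, not something this file pins. Combined with `ExtractCLSFeatures` from above, the wrapper yields CLS embeddings.

from eva.core.models.networks.transforms.extract_cls_features import ExtractCLSFeatures
from eva.core.models.networks.wrappers import HuggingFaceModel

model = HuggingFaceModel(
    model_name_or_path="facebook/dino-vits16",  # example hub id
    tensor_transforms=ExtractCLSFeatures(),
)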
eva/core/models/networks/wrappers/onnx.py
@@ -0,0 +1,47 @@
+"""Wrapper class for ONNX models."""
+
+from typing import Any, Callable, Literal
+
+import onnxruntime as ort
+import torch
+from typing_extensions import override
+
+from eva.core.models.networks.wrappers import base
+
+
+class ONNXModel(base.BaseModel):
+    """Wrapper class for loading ONNX models."""
+
+    def __init__(
+        self,
+        path: str,
+        device: Literal["cpu", "cuda"] | None = "cpu",
+        tensor_transforms: Callable | None = None,
+    ):
+        """Initializes the model.
+
+        Args:
+            path: The path to the .onnx model file.
+            device: The device to run the model on. This can be either "cpu" or "cuda".
+            tensor_transforms: The transforms to apply to the output tensor produced by the model.
+        """
+        super().__init__(tensor_transforms=tensor_transforms)
+
+        self._path = path
+        self._device = device
+        self._model = self.load_model()
+
+    @override
+    def load_model(self) -> Any:
+        if self._device == "cuda" and not torch.cuda.is_available():
+            raise ValueError("Device is set to 'cuda', but CUDA is not available.")
+        provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
+        return ort.InferenceSession(self._path, providers=[provider])
+
+    @override
+    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        # TODO: Use IO binding to avoid copying the tensor to CPU.
+        # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
+        inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
+        outputs = self._model.run(None, inputs)[0]
+        return torch.from_numpy(outputs).float().to(tensor.device)
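Usage sketch; the .onnx path is a placeholder and the input shape depends on the exported network.

import torch

from eva.core.models.networks.wrappers import ONNXModel

model = ONNXModel(path="model.onnx", device="cpu")  # placeholder path
embeddings = model(torch.randn(1, 3, 224, 224))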