kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,190 @@
1
+ """Base model module."""
2
+
3
+ from typing import Any, Mapping
4
+
5
+ import lightning.pytorch as pl
6
+ import torch
7
+ from lightning.pytorch.utilities import memory
8
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
9
+ from typing_extensions import override
10
+
11
+ from eva.core.metrics import structs as metrics_lib
12
+ from eva.core.models.modules.typings import INPUT_BATCH
13
+ from eva.core.models.modules.utils import batch_postprocess
14
+
15
+
16
class ModelModule(pl.LightningModule):
    """Base Lightning module which wires metric tracking into the step hooks."""

    def __init__(
        self,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Builds the module together with its metric trackers.

        Args:
            metrics: The metrics schema to track. When omitted (or falsy),
                `default_metrics` is used instead.
            postprocess: The batch post-process schema, which is applied to
                the step outputs before they reach the metrics. When omitted
                (or falsy), `default_postprocess` is used instead.
        """
        super().__init__()

        self._metrics = metrics or self.default_metrics
        self._postprocess = postprocess or self.default_postprocess

        self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)

    @property
    def default_metrics(self) -> metrics_lib.MetricsSchema:
        """The metrics schema used when none is supplied (an empty one)."""
        return metrics_lib.MetricsSchema()

    @property
    def default_postprocess(self) -> batch_postprocess.BatchPostProcess:
        """The post-process schema used when none is supplied (an empty one)."""
        return batch_postprocess.BatchPostProcess()

    @override
    def on_train_batch_end(
        self,
        outputs: STEP_OUTPUT,
        batch: INPUT_BATCH,
        batch_idx: int,
    ) -> None:
        # Training metrics are evaluated and logged on every step.
        processed = self._common_batch_end(outputs)
        self._forward_and_log_metrics(
            self.metrics.training_metrics,
            batch_outputs=processed,
        )

    @override
    def on_validation_batch_end(
        self,
        outputs: STEP_OUTPUT,
        batch: INPUT_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Validation metrics are only accumulated here; they are computed
        # and logged in `on_validation_epoch_end`.
        processed = self._common_batch_end(outputs)
        self._update_metrics(
            self.metrics.validation_metrics,
            outputs=processed,
            dataloader_idx=dataloader_idx,
        )

    @override
    def on_validation_epoch_end(self) -> None:
        self._compute_and_log_metrics(self.metrics.validation_metrics)

    @override
    def on_test_batch_end(
        self,
        outputs: STEP_OUTPUT,
        batch: INPUT_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Test metrics are only accumulated here; they are computed and
        # logged in `on_test_epoch_end`.
        processed = self._common_batch_end(outputs)
        self._update_metrics(
            self.metrics.test_metrics,
            outputs=processed,
            dataloader_idx=dataloader_idx,
        )

    @override
    def on_test_epoch_end(self) -> None:
        self._compute_and_log_metrics(self.metrics.test_metrics)

    def _common_batch_end(self, outputs: STEP_OUTPUT) -> STEP_OUTPUT:
        """Shared batch-end step of the training, validation and test hooks.

        The post-processes are applied to `outputs` first, after which the
        outputs are recursively detached. The detached values are moved to
        the CPU only when the module device type is `"cpu"`.

        Args:
            outputs: The batch step outputs.

        Returns:
            The post-processed and detached outputs.
        """
        self._postprocess(outputs)
        move_to_cpu = self.device.type == "cpu"
        return memory.recursive_detach(outputs, to_cpu=move_to_cpu)

    def _forward_and_log_metrics(
        self,
        metrics: metrics_lib.MetricCollection,
        batch_outputs: STEP_OUTPUT,
    ) -> None:
        """Runs a forward pass of the metrics on a batch and logs them per step.

        Args:
            metrics: The metrics tracker to evaluate and log.
            batch_outputs: The outputs of the batch processing step.
        """
        metric_kwargs = self._parse_metrics_inputs(batch_outputs)
        metrics(**metric_kwargs)
        self.log_dict(metrics, on_step=True, on_epoch=False)

    def _update_metrics(
        self,
        metrics: metrics_lib.MetricCollection,
        outputs: STEP_OUTPUT,
        dataloader_idx: int = 0,
    ) -> None:
        """Feeds new batch data to the metrics tracker without logging.

        Args:
            metrics: The metrics tracker to update.
            outputs: The outputs of the batch processing step.
            dataloader_idx: The dataloader index.
        """
        metric_kwargs = self._parse_metrics_inputs(outputs, dataloader_idx)
        metrics.update(**metric_kwargs)

    def _compute_and_log_metrics(self, metrics: metrics_lib.MetricCollection) -> None:
        """Computes the accumulated metrics, logs them and resets the tracker.

        Args:
            metrics: The metrics tracker to compute, log and reset.
        """
        computed = metrics.compute()
        self.log_dict(computed)
        metrics.reset()

    def _parse_metrics_inputs(
        self,
        outputs: STEP_OUTPUT,
        dataloader_idx: int = 0,
    ) -> Mapping[str, Any]:
        """Builds the keyword arguments which are passed to the metrics.

        A bare tensor output is treated as the loss. The `predictions`,
        `targets` and `metadata` entries of the outputs are additionally
        exposed under the `preds`, `target` and `metadata` keys, next to
        the dataloader index. Keys present in `outputs` take precedence
        over these derived entries.

        Args:
            outputs: The outputs of the batch processing step.
            dataloader_idx: The dataloader index.

        Returns:
            A mapping of argument name to value; empty if `outputs` is `None`.
        """
        if outputs is None:
            return {}

        step_outputs = {"loss": outputs} if isinstance(outputs, torch.Tensor) else outputs

        derived_inputs = {
            "preds": step_outputs.get("predictions"),
            "target": step_outputs.get("targets"),
            "metadata": step_outputs.get("metadata"),
            "dataloader_idx": dataloader_idx,
        }
        return {**derived_inputs, **step_outputs}
@@ -0,0 +1,23 @@
1
+ """Type annotations for model modules."""
2
+
3
+ from typing import Any, Dict, NamedTuple
4
+
5
+ import lightning.pytorch as pl
6
+ import torch
7
+ from torch import nn
8
+
9
+ MODEL_TYPE = nn.Module | pl.LightningModule
10
+ """The expected model type."""
11
+
12
+
13
+ class INPUT_BATCH(NamedTuple):
14
+ """The default input batch data scheme."""
15
+
16
+ data: torch.Tensor
17
+ """The data batch."""
18
+
19
+ targets: torch.Tensor | Dict[str, Any] | None = None
20
+ """The target batch."""
21
+
22
+ metadata: Dict[str, Any] | None = None
23
+ """The associated metadata."""
@@ -0,0 +1,6 @@
1
+ """Utilities and helper functionalities for model modules."""
2
+
3
+ from eva.core.models.modules.utils import grad
4
+ from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess
5
+
6
+ __all__ = ["grad", "BatchPostProcess"]
@@ -0,0 +1,57 @@
1
+ """Batch post-processes module."""
2
+
3
+ import dataclasses
4
+ import functools
5
+ from typing import Callable, List
6
+
7
+ import torch
8
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
9
+
10
+ Transform = Callable[[torch.Tensor], torch.Tensor]
11
+ """Post-process transform type."""
12
+
13
+
14
@dataclasses.dataclass(frozen=True)
class BatchPostProcess:
    """Schema of the transforms applied to a step output before the metrics."""

    targets_transforms: List[Transform] | None = None
    """Transforms applied, in order, to the `targets` entry of the outputs."""

    predictions_transforms: List[Transform] | None = None
    """Transforms applied, in order, to the `predictions` entry of the outputs."""

    def __call__(self, outputs: STEP_OUTPUT) -> None:
        """Transforms the `targets` and `predictions` of a batch output in-place.

        Outputs which are not a dictionary are left untouched. An entry is
        only processed when its key is present and the respective list of
        transforms is not `None`.

        Args:
            outputs: The batch output of the model module step.
        """
        if not isinstance(outputs, dict):
            return

        # Targets are processed before predictions.
        pipelines = (
            ("targets", self.targets_transforms),
            ("predictions", self.predictions_transforms),
        )
        for key, transforms in pipelines:
            if key in outputs and transforms is not None:
                outputs[key] = _apply_transforms(outputs[key], transforms=transforms)
45
+
46
+
47
def _apply_transforms(tensor: torch.Tensor, transforms: List[Transform]) -> torch.Tensor:
    """Pipes a tensor through a list of transforms, one after the other.

    Args:
        tensor: The tensor to process.
        transforms: The transforms to apply, in order. An empty list
            returns the input unchanged.

    Returns:
        The output of the last transform.
    """
    result = tensor
    for transform in transforms:
        result = transform(result)
    return result
@@ -0,0 +1,23 @@
1
+ """Gradient related utilities and helper functions."""
2
+
3
+ from torch import nn
4
+
5
+
6
def deactivate_requires_grad(module: nn.Module) -> None:
    """Switches off gradient tracking for every parameter of a module.

    Args:
        module: The torch module whose parameters are updated in place.
    """
    for param in module.parameters():
        param.requires_grad_(False)
14
+
15
+
16
def activate_requires_grad(module: nn.Module) -> None:
    """Switches on gradient tracking for every parameter of a module.

    Args:
        module: The torch module whose parameters are updated in place.
    """
    for param in module.parameters():
        param.requires_grad_(True)
@@ -0,0 +1,6 @@
1
+ """Networks API."""
2
+
3
+ from eva.core.models.networks.mlp import MLP
4
+ from eva.core.models.networks.wrappers import HuggingFaceModel, ModelFromFunction, ONNXModel
5
+
6
+ __all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel", "MLP"]
@@ -0,0 +1,25 @@
1
+ """Utilities and helper functions for models."""
2
+
3
+ from lightning_fabric.utilities import cloud_io
4
+ from loguru import logger
5
+ from torch import nn
6
+
7
+
8
def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
    """Loads weights from a checkpoint file into the model in-place.

    The file is opened through the filesystem resolved for `checkpoint_path`
    and loaded onto the CPU. If the loaded object is a dictionary with a
    `state_dict` entry, that entry is used as the weights; otherwise the
    loaded object itself is. The weights are loaded strictly.

    Args:
        model: The model to load the weights to.
        checkpoint_path: The path to the model weights/checkpoint.
    """
    logger.info(f"Loading '{model.__class__.__name__}' model from checkpoint '{checkpoint_path}'")

    filesystem = cloud_io.get_filesystem(checkpoint_path)
    # NOTE(review): `cloud_io._load` is a private helper and presumably
    # unpickles the file -- only point this at trusted checkpoints.
    with filesystem.open(checkpoint_path, "rb") as stream:
        loaded = cloud_io._load(stream, map_location="cpu")  # type: ignore

    # Unwrap the weights when they are nested under the "state_dict" key.
    has_nested_weights = isinstance(loaded, dict) and "state_dict" in loaded
    state_dict = loaded["state_dict"] if has_nested_weights else loaded

    model.load_state_dict(state_dict, strict=True)

    logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
@@ -0,0 +1,69 @@
1
+ """Multi-layer Perceptron (MLP) implemented in PyTorch."""
2
+
3
+ from typing import Type
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class MLP(nn.Module):
10
+ """A Multi-layer Perceptron (MLP) network."""
11
+
12
+ def __init__(
13
+ self,
14
+ input_size: int,
15
+ output_size: int,
16
+ hidden_layer_sizes: tuple[int, ...] | None = None,
17
+ hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
18
+ output_activation_fn: Type[torch.nn.Module] | None = None,
19
+ dropout: float = 0.0,
20
+ ) -> None:
21
+ """Initializes the MLP.
22
+
23
+ Args:
24
+ input_size: The number of input features.
25
+ output_size: The number of output features.
26
+ hidden_layer_sizes: A list specifying the number of units in each hidden layer.
27
+ dropout: Dropout probability for hidden layers.
28
+ hidden_activation_fn: Activation function to use for hidden layers. Default is ReLU.
29
+ output_activation_fn: Activation function to use for the output layer. Default is None.
30
+ """
31
+ super().__init__()
32
+
33
+ self.input_size = input_size
34
+ self.output_size = output_size
35
+ self.hidden_layer_sizes = hidden_layer_sizes if hidden_layer_sizes is not None else ()
36
+ self.hidden_activation_fn = hidden_activation_fn
37
+ self.output_activation_fn = output_activation_fn
38
+ self.dropout = dropout
39
+
40
+ self._network = self._build_network()
41
+
42
+ def _build_network(self) -> nn.Sequential:
43
+ """Builds the neural network's layers and returns a nn.Sequential container."""
44
+ layers = []
45
+ prev_size = self.input_size
46
+ for size in self.hidden_layer_sizes:
47
+ layers.append(nn.Linear(prev_size, size))
48
+ if self.hidden_activation_fn is not None:
49
+ layers.append(self.hidden_activation_fn())
50
+ if self.dropout > 0:
51
+ layers.append(nn.Dropout(self.dropout))
52
+ prev_size = size
53
+
54
+ layers.append(nn.Linear(prev_size, self.output_size))
55
+ if self.output_activation_fn is not None:
56
+ layers.append(self.output_activation_fn())
57
+
58
+ return nn.Sequential(*layers)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """Defines the forward pass of the MLP.
62
+
63
+ Args:
64
+ x: The input tensor.
65
+
66
+ Returns:
67
+ The output of the network.
68
+ """
69
+ return self._network(x)
@@ -0,0 +1,5 @@
1
+ """Model outputs transforms API."""
2
+
3
+ from eva.core.models.networks.transforms.extract_cls_features import ExtractCLSFeatures
4
+
5
+ __all__ = ["ExtractCLSFeatures"]
@@ -0,0 +1,25 @@
1
+ """Transforms for extracting the CLS output from a model output."""
2
+
3
+ import torch
4
+ from transformers import modeling_outputs
5
+
6
+
7
class ExtractCLSFeatures:
    """Callable which extracts the CLS token from a ViT model output."""

    def __call__(
        self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
    ) -> torch.Tensor:
        """Selects index 0 along the second axis of the model output.

        Args:
            tensor: The model output; either a plain tensor or a HuggingFace
                output, of which the `last_hidden_state` is used.

        Returns:
            The `[:, 0, :]` slice of the (hidden state) tensor.

        Raises:
            ValueError: If the input is of neither supported type.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor[:, 0, :]

        if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
            return tensor.last_hidden_state[:, 0, :]

        raise ValueError(f"Unsupported type {type(tensor)}")
@@ -0,0 +1,8 @@
1
+ """Model Wrappers API."""
2
+
3
+ from eva.core.models.networks.wrappers.base import BaseModel
4
+ from eva.core.models.networks.wrappers.from_function import ModelFromFunction
5
+ from eva.core.models.networks.wrappers.huggingface import HuggingFaceModel
6
+ from eva.core.models.networks.wrappers.onnx import ONNXModel
7
+
8
+ __all__ = ["BaseModel", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"]
@@ -0,0 +1,47 @@
1
+ """Base class for model wrappers."""
2
+
3
+ import abc
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing_extensions import override
9
+
10
+
11
class BaseModel(nn.Module):
    """Base class for model wrappers.

    Subclasses provide `load_model` and `model_forward`; `forward` runs
    `model_forward` and then the optional output transforms.

    NOTE(review): the class does not use `abc.ABCMeta`, so the
    `abc.abstractmethod` markers are not enforced at instantiation time;
    the `NotImplementedError` bodies act as the effective guard.
    """

    def __init__(self, tensor_transforms: Callable | None = None) -> None:
        """Initializes the model.

        Args:
            tensor_transforms: The callable to apply to the output produced
                by the model. `None` leaves the output as it is.
        """
        super().__init__()

        self._output_transforms = tensor_transforms

    @override
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        model_output = self.model_forward(tensor)
        return self._apply_transforms(model_output)

    @abc.abstractmethod
    def load_model(self) -> Callable[..., torch.Tensor]:
        """Loads the model."""
        raise NotImplementedError

    @abc.abstractmethod
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Implements the forward pass of the model.

        Args:
            tensor: The input tensor to the model.
        """
        raise NotImplementedError

    def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
        """Applies the output transforms, if any were configured."""
        if self._output_transforms is None:
            return tensor
        return self._output_transforms(tensor)
@@ -0,0 +1,58 @@
1
+ """Helper function from models defined with a function."""
2
+
3
+ from typing import Any, Callable, Dict
4
+
5
+ import jsonargparse
6
+ import torch
7
+ from torch import nn
8
+ from typing_extensions import override
9
+
10
+ from eva.core.models.networks import _utils
11
+ from eva.core.models.networks.wrappers import base
12
+
13
+
14
class ModelFromFunction(base.BaseModel):
    """Wrapper class for models which are initialized from functions.

    This is helpful for initializing models in a `.yaml` configuration file.
    """

    def __init__(
        self,
        path: Callable[..., nn.Module],
        arguments: Dict[str, Any] | None = None,
        checkpoint_path: str | None = None,
        tensor_transforms: Callable | None = None,
    ) -> None:
        """Initializes and constructs the model.

        Args:
            path: The path to the callable object (class or function).
            arguments: The extra callable function / class arguments.
            checkpoint_path: The path to the checkpoint to load the model
                weights from. This is currently only supported for torch
                model checkpoints. For other formats, the checkpoint loading
                should be handled within the provided callable object in <path>.
            tensor_transforms: The transforms to apply to the output tensor
                produced by the model.
        """
        # The transforms must reach the base class: `BaseModel.forward` only
        # applies what was handed to its constructor. Without forwarding
        # them here, `tensor_transforms` would be silently ignored.
        super().__init__(tensor_transforms=tensor_transforms)

        self._path = path
        self._arguments = arguments
        self._checkpoint_path = checkpoint_path
        # Kept for backward compatibility with code reading this attribute.
        self._tensor_transforms = tensor_transforms

        self._model = self.load_model()

    @override
    def load_model(self) -> nn.Module:
        """Builds the model from the callable and optionally loads its weights.

        Returns:
            The constructed model, with the checkpoint weights loaded when
            a `checkpoint_path` was given.
        """
        class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
        model = class_path(**(self._arguments or {}))
        if self._checkpoint_path is not None:
            _utils.load_model_weights(model, self._checkpoint_path)
        return model

    @override
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Runs the wrapped model on the input tensor."""
        return self._model(tensor)
@@ -0,0 +1,37 @@
1
+ """Wrappers for HuggingFace `transformers` models."""
2
+
3
+ from typing import Any, Callable
4
+
5
+ import torch
6
+ import transformers
7
+ from typing_extensions import override
8
+
9
+ from eva.core.models.networks.wrappers import base
10
+
11
+
12
class HuggingFaceModel(base.BaseModel):
    """Wrapper class for loading HuggingFace `transformers` models."""

    def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
        """Initializes the wrapper and loads the model.

        Args:
            model_name_or_path: The model name or path to load the model
                from. This can be a local path or a model name from the
                `HuggingFace` model hub.
            tensor_transforms: The transforms to apply to the output
                produced by the model.
        """
        super().__init__(tensor_transforms=tensor_transforms)

        self._model_name_or_path = model_name_or_path
        self._model = self.load_model()

    @override
    def load_model(self) -> Any:
        # The configuration is resolved first and handed to the model loader.
        source = self._model_name_or_path
        model_config = transformers.AutoConfig.from_pretrained(source)
        return transformers.AutoModel.from_pretrained(source, config=model_config)

    @override
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self._model(tensor)
@@ -0,0 +1,47 @@
1
+ """Wrapper class for ONNX models."""
2
+
3
+ from typing import Any, Callable, Literal
4
+
5
+ import onnxruntime as ort
6
+ import torch
7
+ from typing_extensions import override
8
+
9
+ from eva.core.models.networks.wrappers import base
10
+
11
+
12
class ONNXModel(base.BaseModel):
    """Wrapper class for loading ONNX models."""

    def __init__(
        self,
        path: str,
        device: Literal["cpu", "cuda"] | None = "cpu",
        tensor_transforms: Callable | None = None,
    ):
        """Initializes the wrapper and creates the inference session.

        Args:
            path: The path to the .onnx model file.
            device: The device to run the model on. `"cuda"` selects the
                CUDA execution provider; any other value selects the CPU one.
            tensor_transforms: The transforms to apply to the output tensor
                produced by the model.
        """
        super().__init__(tensor_transforms=tensor_transforms)

        self._path = path
        self._device = device
        self._model = self.load_model()

    @override
    def load_model(self) -> Any:
        use_cuda = self._device == "cuda"
        if use_cuda and not torch.cuda.is_available():
            raise ValueError("Device is set to 'cuda', but CUDA is not available.")

        execution_provider = "CUDAExecutionProvider" if use_cuda else "CPUExecutionProvider"
        return ort.InferenceSession(self._path, providers=[execution_provider])

    @override
    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # TODO: Use IO binding to avoid copying the tensor to CPU.
        # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
        input_name = self._model.get_inputs()[0].name
        feed = {input_name: tensor.detach().cpu().numpy()}
        # Only the first output of the session is used; it is cast to float
        # and placed back on the device of the input tensor.
        first_output = self._model.run(None, feed)[0]
        return torch.from_numpy(first_output).float().to(tensor.device)
@@ -0,0 +1,6 @@
1
+ """Trainers API."""
2
+
3
+ from eva.core.trainers.functional import infer_model, run_evaluation_session
4
+ from eva.core.trainers.trainer import Trainer
5
+
6
+ __all__ = ["infer_model", "run_evaluation_session", "Trainer"]