kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (111) hide show
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,76 @@
1
+ """Default metric collection for binary classification tasks."""
2
+
3
+ from torchmetrics import classification
4
+
5
+ from eva.core.metrics import binary_balanced_accuracy, structs
6
+
7
+
8
class BinaryClassificationMetrics(structs.MetricCollection):
    """Default metrics for binary classification tasks."""

    def __init__(
        self,
        threshold: float = 0.5,
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Builds the default binary classification metric collection.

        The collection contains, in this order:

        - BinaryAUROC
        - BinaryAccuracy
        - BinaryBalancedAccuracy
        - BinaryF1Score
        - BinaryPrecision
        - BinaryRecall

        Args:
            threshold: Threshold for transforming probabilities into binary
                (0, 1) predictions. It is not passed to `BinaryAUROC`.
            ignore_index: A target value that is ignored and does not
                contribute to the metric calculation.
            prefix: A string to prepend to the keys of the output dict.
            postfix: A string to append to the keys of the output dict.
        """
        thresholded_metric_types = (
            classification.BinaryAccuracy,
            binary_balanced_accuracy.BinaryBalancedAccuracy,
            classification.BinaryF1Score,
            classification.BinaryPrecision,
            classification.BinaryRecall,
        )

        # AUROC is built first and receives no `threshold` argument.
        metric_instances = [classification.BinaryAUROC(ignore_index=ignore_index)]
        metric_instances.extend(
            metric_type(threshold=threshold, ignore_index=ignore_index)
            for metric_type in thresholded_metric_types
        )

        # Explicit compute groups: all threshold-based metrics together,
        # AUROC on its own.
        groups = [
            [
                "BinaryAccuracy",
                "BinaryBalancedAccuracy",
                "BinaryF1Score",
                "BinaryPrecision",
                "BinaryRecall",
            ],
            ["BinaryAUROC"],
        ]

        super().__init__(
            metrics=metric_instances,
            prefix=prefix,
            postfix=postfix,
            compute_groups=groups,
        )
@@ -0,0 +1,80 @@
1
+ """Default metric collection for multiclass classification tasks."""
2
+
3
+ from typing import Literal
4
+
5
+ from torchmetrics import classification
6
+
7
+ from eva.core.metrics import structs
8
+
9
+
10
class MulticlassClassificationMetrics(structs.MetricCollection):
    """Default metrics for multi-class classification tasks."""

    def __init__(
        self,
        num_classes: int,
        average: Literal["macro", "weighted", "none"] = "macro",
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Builds the default multi-class classification metric collection.

        The collection contains, in this order:

        - MulticlassAUROC
        - MulticlassAccuracy
        - MulticlassF1Score
        - MulticlassPrecision
        - MulticlassRecall

        Args:
            num_classes: Integer specifying the number of classes.
            average: Defines the reduction that is applied over labels.
            ignore_index: A target value that is ignored and does not
                contribute to the metric calculation.
            prefix: A string to prepend to the keys of the output dict.
            postfix: A string to append to the keys of the output dict.
        """
        # Every metric is configured with exactly the same arguments.
        shared_kwargs = {
            "num_classes": num_classes,
            "average": average,
            "ignore_index": ignore_index,
        }
        metric_types = (
            classification.MulticlassAUROC,
            classification.MulticlassAccuracy,
            classification.MulticlassF1Score,
            classification.MulticlassPrecision,
            classification.MulticlassRecall,
        )
        metric_instances = [metric_type(**shared_kwargs) for metric_type in metric_types]

        # Explicit compute groups: the label-based metrics together,
        # AUROC on its own.
        groups = [
            [
                "MulticlassAccuracy",
                "MulticlassF1Score",
                "MulticlassPrecision",
                "MulticlassRecall",
            ],
            ["MulticlassAUROC"],
        ]

        super().__init__(
            metrics=metric_instances,
            prefix=prefix,
            postfix=postfix,
            compute_groups=groups,
        )
@@ -0,0 +1,9 @@
1
+ """Core metrics modules API."""
2
+
3
+ from eva.core.metrics.structs.collection import MetricCollection
4
+ from eva.core.metrics.structs.metric import Metric
5
+ from eva.core.metrics.structs.module import MetricModule
6
+ from eva.core.metrics.structs.schemas import MetricsSchema
7
+ from eva.core.metrics.structs.typings import MetricModuleType
8
+
9
# Public names re-exported by the metric structs package.
__all__ = [
    "MetricCollection",
    "Metric",
    "MetricModule",
    "MetricsSchema",
    "MetricModuleType",
]
@@ -0,0 +1,6 @@
1
+ """Metric collection aggregator."""
2
+
3
+ import torchmetrics
4
+
5
# Defines a metric aggregator object (re-export of the torchmetrics class).
MetricCollection = torchmetrics.MetricCollection
@@ -0,0 +1,6 @@
1
+ """Base class of metrics."""
2
+
3
+ import torchmetrics
4
+
5
# Abstract metric class (re-export of the torchmetrics base class).
Metric = torchmetrics.Metric
@@ -0,0 +1,115 @@
1
+ """Metrics module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from torch import nn
6
+
7
+ from eva.core.metrics.structs import collection, schemas
8
+ from eva.core.metrics.structs.typings import MetricModuleType
9
+
10
+
11
class MetricModule(nn.Module):
    """Holds the metric collections of the `train`, `val` and `test` stages."""

    def __init__(
        self,
        train: collection.MetricCollection | None,
        val: collection.MetricCollection | None,
        test: collection.MetricCollection | None,
    ) -> None:
        """Initializes the metrics for the Trainer.

        A stage whose collection is falsy (for example `None`) receives a
        new, empty collection instead.

        Args:
            train: The training metric collection.
            val: The validation metric collection.
            test: The test metric collection.
        """
        super().__init__()

        self._train = train if train else self.default_metric_collection
        self._val = val if val else self.default_metric_collection
        self._test = test if test else self.default_metric_collection

    @property
    def default_metric_collection(self) -> collection.MetricCollection:
        """A new, empty metric collection (a fresh instance on every access)."""
        return collection.MetricCollection([])

    @classmethod
    def from_metrics(
        cls,
        train: MetricModuleType | None,
        val: MetricModuleType | None,
        test: MetricModuleType | None,
        *,
        separator: str = "/",
    ) -> MetricModule:
        """Builds a metric module from per-stage metrics.

        Each stage's metrics are wrapped into their own collection whose
        keys are prefixed with the stage name followed by `separator`.

        Args:
            train: Metrics for the training stage.
            val: Metrics for the validation stage.
            test: Metrics for the test stage.
            separator: The separator between the group name of the metric
                and the metric itself.

        Returns:
            The metric module.
        """
        per_stage = (("train", train), ("val", val), ("test", test))
        collections_by_stage = {
            stage: _create_collection_from_metrics(stage_metrics, prefix=stage + separator)
            for stage, stage_metrics in per_stage
        }
        return cls(**collections_by_stage)

    @classmethod
    def from_schema(
        cls,
        schema: schemas.MetricsSchema,
        *,
        separator: str = "/",
    ) -> MetricModule:
        """Builds a metric module from a metrics schema.

        The schema's training metrics feed the `train` stage, while its
        evaluation metrics feed both the `val` and the `test` stages.

        Args:
            schema: The dataclass metric schema.
            separator: The separator between the group name of the metric
                and the metric itself.

        Returns:
            The metric module.
        """
        training = schema.training_metrics
        return cls.from_metrics(
            train=training,
            val=schema.evaluation_metrics,
            test=schema.evaluation_metrics,
            separator=separator,
        )

    @property
    def training_metrics(self) -> collection.MetricCollection:
        """The metric collection of the training stage."""
        return self._train

    @property
    def validation_metrics(self) -> collection.MetricCollection:
        """The metric collection of the validation stage."""
        return self._val

    @property
    def test_metrics(self) -> collection.MetricCollection:
        """The metric collection of the test stage."""
        return self._test
100
+
101
+
102
def _create_collection_from_metrics(
    metrics: MetricModuleType | None, *, prefix: str | None = None
) -> collection.MetricCollection:
    """Creates an independent metric collection from the given metrics.

    Args:
        metrics: The desired metrics. `None` (or any other falsy value)
            yields an empty collection.
        prefix: A prefix to add to the keys of the collection.

    Returns:
        The resulting metric collection. It is always a collection, never
        `None`.
    """
    metrics_collection = collection.MetricCollection(metrics or [], prefix=prefix)  # type: ignore
    # Return a clone rather than the collection wrapping the caller's metric
    # objects: the same metric objects can be handed to several stages
    # (e.g. `MetricModule.from_schema` passes the evaluation metrics to both
    # `val` and `test`).
    return metrics_collection.clone()
@@ -0,0 +1,47 @@
1
+ """Metrics related helper schemas."""
2
+
3
+ import dataclasses
4
+
5
+ from eva.core.metrics.structs.typings import MetricModuleType
6
+
7
+
8
@dataclasses.dataclass(frozen=True)
class MetricsSchema:
    """Metrics schema.

    Groups the metrics into those shared by training and evaluation
    (`common`) and those exclusive to each of them (`train`, `evaluation`).
    """

    common: MetricModuleType | None = None
    """Holds the common train and evaluation metrics."""

    train: MetricModuleType | None = None
    """The exclusive training metrics."""

    evaluation: MetricModuleType | None = None
    """The exclusive evaluation metrics."""

    @property
    def training_metrics(self) -> MetricModuleType | None:
        """Returns the training metrics (common + exclusive training ones)."""
        return self._join_with_common(self.train)

    @property
    def evaluation_metrics(self) -> MetricModuleType | None:
        """Returns the evaluation metrics (common + exclusive evaluation ones)."""
        return self._join_with_common(self.evaluation)

    def _join_with_common(self, metrics: MetricModuleType | None) -> MetricModuleType | None:
        """Joins the provided metrics with the common.

        Note that if there is duplication of metrics between the provided and
        the common (meaning there is the same metric in `metrics` and in
        `self.common`) both will be preserved.

        Args:
            metrics: The metrics to join.

        Returns:
            The resulting metrics after joining with the common ones. If one
            side is `None`, the other is returned unchanged. Otherwise a flat
            list is returned: the common metrics first, then `metrics`.
        """
        if metrics is None or self.common is None:
            return self.common or metrics

        # Flatten list/tuple operands one level, so joining two sequences
        # does not produce a nested list. The result is meant to be passed to
        # a MetricCollection, which expects a flat sequence of metrics.
        # NOTE(review): mapping-valued metrics are kept as a single entry;
        # confirm the consumer accepts them when mixed with other metrics.
        return [*self._as_list(self.common), *self._as_list(metrics)]  # type: ignore

    @staticmethod
    def _as_list(metrics: MetricModuleType) -> list:
        """Returns a list/tuple as a list copy; wraps any other value in a list."""
        if isinstance(metrics, (list, tuple)):
            return list(metrics)
        return [metrics]
@@ -0,0 +1,15 @@
1
+ """Metric typings."""
2
+
3
+ from typing import Dict, Sequence, Union
4
+
5
+ from eva.core.metrics.structs import collection, metric
6
+
7
# The base module metric type: a single metric or a metric collection.
BaseMetricModuleType = Union[metric.Metric, collection.MetricCollection]

# The module metric type: a base metric type, a sequence of base metric
# types, or a mapping from names to base metric types.
MetricModuleType = Union[
    BaseMetricModuleType,
    Sequence[BaseMetricModuleType],
    Dict[str, BaseMetricModuleType],
]
@@ -0,0 +1,13 @@
1
+ """Models API."""
2
+
3
+ from eva.core.models.modules import HeadModule, InferenceModule
4
+ from eva.core.models.networks import MLP, HuggingFaceModel, ModelFromFunction, ONNXModel
5
+
6
# Public model API: modules first, then networks.
__all__ = ["HeadModule", "InferenceModule", "MLP", "HuggingFaceModel", "ModelFromFunction", "ONNXModel"]
+ ]
@@ -0,0 +1,7 @@
1
+ """Model Modules API."""
2
+
3
+ from eva.core.models.modules.head import HeadModule
4
+ from eva.core.models.modules.inference import InferenceModule
5
+ from eva.core.models.modules.module import ModelModule
6
+
7
# Public model-module API.
__all__ = [
    "HeadModule",
    "ModelModule",
    "InferenceModule",
]
@@ -0,0 +1,113 @@
1
"""Neural Network Head Module."""
2
+
3
+ from typing import Any, Callable
4
+
5
+ import torch
6
+ from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
7
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
8
+ from torch import optim
9
+ from torch.optim import lr_scheduler
10
+ from typing_extensions import override
11
+
12
+ from eva.core.metrics import structs as metrics_lib
13
+ from eva.core.models.modules import module
14
+ from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
15
+ from eva.core.models.modules.utils import batch_postprocess, grad
16
+
17
+
18
class HeadModule(module.ModelModule):
    """Trains a head network on features, optionally from a frozen backbone.

    Intended for supervised (mini-batch) stochastic gradient descent
    downstream tasks such as classification, regression and segmentation.
    """

    def __init__(
        self,
        head: MODEL_TYPE,
        criterion: Callable[..., torch.Tensor],
        backbone: MODEL_TYPE | None = None,
        optimizer: OptimizerCallable = optim.Adam,
        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the neural net head module.

        Args:
            head: The neural network that is trained on the features.
            criterion: The loss function to use.
            backbone: The feature extractor. If `None`, the input batch is
                expected to contain the features directly.
            optimizer: Factory for the optimizer; it receives the head's
                parameters only.
            lr_scheduler: Factory for the learning rate scheduler; it
                receives the created optimizer.
            metrics: The metric groups to track.
            postprocess: Post-processing applied to the model predictions
                and targets after the loss and before the metrics
                calculation; forwarded to the base `ModelModule`.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

        # Kept in this order so submodules are registered as before
        # (head first, then backbone).
        self.head = head
        self.criterion = criterion
        self.backbone = backbone
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    @override
    def configure_optimizers(self) -> Any:
        # Only the head is optimized; the backbone (if any) stays frozen.
        head_optimizer = self.optimizer(list(self.head.parameters()))
        return {
            "optimizer": head_optimizer,
            "lr_scheduler": self.lr_scheduler(head_optimizer),
        }

    @override
    def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        if self.backbone is not None:
            tensor = self.backbone(tensor)
        # Drop a trailing singleton dimension of the head's output, if any.
        return self.head(tensor).squeeze(-1)

    @override
    def on_fit_start(self) -> None:
        # Freeze the backbone for the duration of the fit.
        if self.backbone is not None:
            grad.deactivate_requires_grad(self.backbone)

    @override
    def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Returns backbone features (or the raw inputs when there is no
        # backbone); the head is not applied here.
        inputs = INPUT_BATCH(*batch).data
        if self.backbone is None:
            return inputs
        return self.backbone(inputs)

    @override
    def on_fit_end(self) -> None:
        # Undo the freezing done in `on_fit_start`.
        if self.backbone is not None:
            grad.activate_requires_grad(self.backbone)

    def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
        """Runs the forward pass on a batch and computes the loss.

        Args:
            batch: The batch to process.

        Returns:
            A dict with the `loss`, `targets`, `predictions` and `metadata`.
        """
        inputs, labels, meta = INPUT_BATCH(*batch)
        outputs = self(inputs)
        batch_loss = self.criterion(outputs, labels)
        return {
            "loss": batch_loss,
            "targets": labels,
            "predictions": outputs,
            "metadata": meta,
        }
@@ -0,0 +1,37 @@
1
+ """Model inference module."""
2
+
3
+ import torch
4
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
5
+ from typing_extensions import override
6
+
7
+ from eva.core.models.modules import module
8
+ from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
9
+
10
+
11
class InferenceModule(module.ModelModule):
    """A lightweight model module that only runs a backbone for inference."""

    def __init__(self, backbone: MODEL_TYPE) -> None:
        """Initializes the module.

        Args:
            backbone: The network to be used for inference.
        """
        super().__init__(metrics=None)

        self.backbone = backbone

    @override
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.backbone(tensor)

    @override
    def predict_step(
        self,
        batch: INPUT_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> STEP_OUTPUT:
        # Only the first batch field (the input data) is used; the rest of
        # the batch is ignored.
        inputs, *_unused = INPUT_BATCH(*batch)
        return self(inputs)