kaiko-eva 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/functional.py +40 -43
  17. eva/core/utils/factory.py +66 -0
  18. eva/core/utils/registry.py +42 -0
  19. eva/core/utils/requirements.py +26 -0
  20. eva/language/__init__.py +13 -0
  21. eva/language/data/__init__.py +5 -0
  22. eva/language/data/datasets/__init__.py +9 -0
  23. eva/language/data/datasets/classification/__init__.py +7 -0
  24. eva/language/data/datasets/classification/base.py +63 -0
  25. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  26. eva/language/data/datasets/language.py +13 -0
  27. eva/language/models/__init__.py +25 -0
  28. eva/language/models/modules/__init__.py +5 -0
  29. eva/language/models/modules/text.py +85 -0
  30. eva/language/models/modules/typings.py +16 -0
  31. eva/language/models/wrappers/__init__.py +11 -0
  32. eva/language/models/wrappers/huggingface.py +69 -0
  33. eva/language/models/wrappers/litellm.py +77 -0
  34. eva/language/models/wrappers/vllm.py +149 -0
  35. eva/language/utils/__init__.py +5 -0
  36. eva/language/utils/str_to_int_tensor.py +95 -0
  37. eva/vision/data/dataloaders/__init__.py +2 -1
  38. eva/vision/data/dataloaders/worker_init.py +35 -0
  39. eva/vision/data/datasets/__init__.py +5 -5
  40. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  41. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  42. eva/vision/data/datasets/segmentation/consep.py +5 -4
  43. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  44. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  45. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  46. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  47. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  48. eva/vision/data/transforms/__init__.py +11 -2
  49. eva/vision/data/transforms/base/__init__.py +5 -0
  50. eva/vision/data/transforms/base/monai.py +27 -0
  51. eva/vision/data/transforms/common/__init__.py +2 -1
  52. eva/vision/data/transforms/common/squeeze.py +24 -0
  53. eva/vision/data/transforms/croppad/__init__.py +4 -0
  54. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  56. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  57. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  58. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  59. eva/vision/models/modules/semantic_segmentation.py +18 -7
  60. eva/vision/models/networks/backbones/__init__.py +2 -3
  61. eva/vision/models/networks/backbones/_utils.py +1 -1
  62. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  63. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  64. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  65. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  66. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  67. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  68. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  72. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  73. eva/vision/models/networks/backbones/registry.py +2 -44
  74. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  75. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  76. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  77. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  78. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  80. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  81. eva/vision/models/wrappers/from_registry.py +14 -9
  82. eva/vision/models/wrappers/from_timm.py +6 -5
  83. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/METADATA +10 -2
  84. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/RECORD +88 -57
  85. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/WHEEL +1 -1
  86. eva/vision/data/datasets/segmentation/lits.py +0 -199
  87. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  88. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  89. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/entry_points.txt +0 -0
  90. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.0.dist-info}/licenses/LICENSE +0 -0
eva/core/utils/factory.py
@@ -0,0 +1,66 @@
+ """Factory classes."""
+
+ import inspect
+ from typing import Any, Dict, Generic, Type, TypeVar
+
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.core.utils.registry import Registry, RegistryItem
+
+ T = TypeVar("T")
+
+
+ class Factory(Generic[T]):
+     """A base factory class for instantiating registry items of a specific type."""
+
+     def __new__(cls, registry: Registry, name: str, init_args: dict, expected_type: Type[T]) -> T:
+         """Creates the appropriate instance based on registry entry.
+
+         Args:
+             registry: The registry containing the items to instantiate.
+             name: Name of the registry item to instantiate.
+             init_args: The arguments to pass to the constructor of the registry item.
+             expected_type: The expected type of the instantiated object.
+         """
+         if name not in registry.entries():
+             raise ValueError(
+                 f"Invalid name: {name}. Please choose one "
+                 f"of the following: {registry.entries()}"
+             )
+
+         registry_item = registry.get(name)
+         filtered_kwargs = _filter_kwargs(registry_item, init_args)
+         instance = registry_item(**filtered_kwargs)
+
+         if not isinstance(instance, expected_type):
+             raise TypeError(f"Expected an instance of {expected_type}, but got {type(instance)}.")
+         return instance
+
+
+ class ModuleFactory(Factory[nn.Module]):
+     """Factory class for instantiating nn.Module instances from a registry."""
+
+     @override
+     def __new__(cls, registry: Registry, name: str, init_args: dict) -> nn.Module:
+         return super().__new__(cls, registry, name, init_args, nn.Module)
+
+
+ def _filter_kwargs(registry_item: RegistryItem, kwargs: dict) -> Dict[str, Any]:
+     """Filters the given keyword arguments to match the signature of a given class or method.
+
+     Args:
+         registry_item: The class or method from the registry whose
+             signature should be used for filtering.
+         kwargs: The keyword arguments to filter.
+
+     Returns:
+         A dictionary containing only the valid keyword arguments that match
+         the callable's parameters.
+     """
+     if inspect.isclass(registry_item):
+         signature = inspect.signature(registry_item.__init__)
+     else:
+         signature = inspect.signature(registry_item)
+
+     return {k: v for k, v in kwargs.items() if k in signature.parameters}
eva/core/utils/registry.py
@@ -0,0 +1,42 @@
+ """Registry for classes and methods."""
+
+ from typing import Any, Callable, Dict, List, Type, Union
+
+ RegistryItem = Union[Type[Any], Callable[..., Any]]
+
+
+ class Registry:
+     """A registry to store and access classes and methods by a unique key."""
+
+     def __init__(self) -> None:
+         """Initializes the registry class."""
+         self._registry: Dict[str, RegistryItem] = {}
+
+     def register(self, key: str, /) -> Callable[[RegistryItem], RegistryItem]:
+         """A decorator to register a class or method with a unique key.
+
+         Args:
+             key: The key to register the class or method under.
+
+         Returns:
+             A decorator that registers the class or method in the registry.
+         """
+
+         def wrapper(obj: RegistryItem) -> RegistryItem:
+             if key in self.entries():
+                 raise ValueError(f"Entry {key} is already registered.")
+
+             self._registry[key] = obj
+             return obj
+
+         return wrapper
+
+     def get(self, name: str) -> RegistryItem:
+         """Gets the class or method from the registry."""
+         if name not in self._registry:
+             raise ValueError(f"Item {name} not found in the registry.")
+         return self._registry[name]
+
+     def entries(self) -> List[str]:
+         """List all items in the registry."""
+         return list(self._registry.keys())
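A minimal usage sketch of the new registry/factory pair (the heads registry and MyHead class below are invented for illustration; note how _filter_kwargs silently drops init_args keys that the constructor does not accept):

    from torch import nn

    from eva.core.utils.factory import ModuleFactory
    from eva.core.utils.registry import Registry

    heads = Registry()

    @heads.register("my_head")  # hypothetical registry key
    class MyHead(nn.Module):
        def __init__(self, in_features: int = 8, out_features: int = 2) -> None:
            super().__init__()
            self.layer = nn.Linear(in_features, out_features)

    # "unused" is not a parameter of MyHead.__init__, so _filter_kwargs drops it.
    head = ModuleFactory(heads, "my_head", {"in_features": 16, "unused": 0})
    assert isinstance(head, MyHead)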
eva/core/utils/requirements.py
@@ -0,0 +1,26 @@
+ """Utility functions related to package requirements."""
+
+ import importlib
+ from typing import Dict
+
+ from packaging import version
+
+
+ def check_dependencies(requirements: Dict[str, str]) -> None:
+     """Check installed package versions against requirements dict.
+
+     Args:
+         requirements: A dictionary where keys are package names and
+             values are minimum required versions.
+
+     Raises:
+         ImportError: If any package does not meet the minimum required version.
+     """
+     for package, min_version in requirements.items():
+         module = importlib.import_module(package)
+         actual = getattr(module, "__version__", None)
+         if actual and not (version.parse(actual) >= version.parse(min_version)):
+             raise ImportError(
+                 f"Package '{package}' version {actual} does not meet "
+                 f"the minimum required version {min_version}."
+             )
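A quick sketch of the helper's behavior (the version pins are illustrative; packages lacking a __version__ attribute pass the check, and importlib raises ModuleNotFoundError for packages that are not installed at all):

    from eva.core.utils.requirements import check_dependencies

    check_dependencies({"torch": "2.0.0"})  # passes silently if torch >= 2.0.0 is installed
    try:
        check_dependencies({"torch": "99.0.0"})  # deliberately unsatisfiable pin
    except ImportError as error:
        print(error)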
eva/language/__init__.py
@@ -0,0 +1,13 @@
+ """eva language API."""
+
+ try:
+     from eva.language.data import datasets
+ except ImportError as e:
+     msg = (
+         "eva language requirements are not installed.\n\n"
+         "Please pip install as follows:\n"
+         ' python -m pip install "kaiko-eva[language]" --upgrade'
+     )
+     raise ImportError(str(e) + "\n\n" + msg) from e
+
+ __all__ = ["datasets"]
eva/language/data/__init__.py
@@ -0,0 +1,5 @@
+ """Language data API."""
+
+ from eva.language.data import datasets
+
+ __all__ = ["datasets"]
eva/language/data/datasets/__init__.py
@@ -0,0 +1,9 @@
+ """Language Datasets API."""
+
+ from eva.language.data.datasets.classification import PubMedQA
+ from eva.language.data.datasets.language import LanguageDataset
+
+ __all__ = [
+     "PubMedQA",
+     "LanguageDataset",
+ ]
eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1,7 @@
+ """Text classification datasets API."""
+
+ from eva.language.data.datasets.classification.pubmedqa import PubMedQA
+
+ __all__ = [
+     "PubMedQA",
+ ]
eva/language/data/datasets/classification/base.py
@@ -0,0 +1,63 @@
+ """Base for text classification datasets."""
+
+ import abc
+ from typing import Any, Dict, List, Tuple
+
+ import torch
+ from typing_extensions import override
+
+ from eva.language.data.datasets.language import LanguageDataset
+
+
+ class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]], abc.ABC):
+     """Text classification abstract dataset."""
+
+     def __init__(self) -> None:
+         """Initializes the text classification dataset."""
+         super().__init__()
+
+     @property
+     def classes(self) -> List[str] | None:
+         """Returns list of class names."""
+
+     @property
+     def class_to_idx(self) -> Dict[str, int] | None:
+         """Returns class name to index mapping."""
+
+     def load_metadata(self, index: int) -> Dict[str, Any] | None:
+         """Returns the dataset metadata.
+
+         Args:
+             index: The index of the data sample.
+
+         Returns:
+             The sample metadata.
+         """
+
+     @abc.abstractmethod
+     def load_text(self, index: int) -> str:
+         """Returns the text content.
+
+         Args:
+             index: The index of the data sample.
+
+         Returns:
+             The text content.
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def load_target(self, index: int) -> torch.Tensor:
+         """Returns the target label.
+
+         Args:
+             index: The index of the data sample.
+
+         Returns:
+             The target label.
+         """
+         raise NotImplementedError
+
+     @override
+     def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
+         return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
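A concrete subclass only has to implement load_text and load_target, as in this minimal sketch (the TinySentiment class and its two in-memory samples are invented for illustration):

    import torch

    from eva.language.data.datasets.classification.base import TextClassification


    class TinySentiment(TextClassification):
        """Hypothetical two-sample dataset illustrating the abstract API."""

        _samples = [("great movie", 1), ("dull plot", 0)]

        def load_text(self, index: int) -> str:
            return self._samples[index][0]

        def load_target(self, index: int) -> torch.Tensor:
            return torch.tensor(self._samples[index][1], dtype=torch.long)

        def __len__(self) -> int:
            return len(self._samples)


    text, target, metadata = TinySentiment()[0]  # ("great movie", tensor(1), {})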
eva/language/data/datasets/classification/pubmedqa.py
@@ -0,0 +1,149 @@
+ """PubMedQA dataset class."""
+
+ import os
+ import random
+ from typing import Dict, List, Literal
+
+ import torch
+ from datasets import Dataset, load_dataset, load_from_disk
+ from loguru import logger
+ from typing_extensions import override
+
+ from eva.language.data.datasets.classification import base
+
+
+ class PubMedQA(base.TextClassification):
+     """Dataset class for PubMedQA question answering task."""
+
+     _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
+     """Dataset license."""
+
+     def __init__(
+         self,
+         root: str | None = None,
+         split: Literal["train", "val", "test"] | None = None,
+         download: bool = False,
+         max_samples: int | None = None,
+     ) -> None:
+         """Initialize the PubMedQA dataset.
+
+         Args:
+             root: Directory to cache the dataset. If None, no local caching is used.
+             split: Valid splits among ["train", "val", "test"].
+                 If None, it will use "train+test+validation".
+             download: Whether to download the dataset if not found locally. Default is False.
+             max_samples: Maximum number of samples to use. If None, use all samples.
+         """
+         super().__init__()
+
+         self._root = root
+         self._split = split
+         self._download = download
+         self._max_samples = max_samples
+
+     def _load_dataset(self, dataset_path: str | None) -> Dataset:
+         """Loads the PubMedQA dataset from the local cache or downloads it.
+
+         Args:
+             dataset_path: The path to the local cache (may be None).
+
+         Returns:
+             The loaded dataset object.
+         """
+         dataset_name = "bigbio/pubmed_qa"
+         config_name = "pubmed_qa_labeled_fold0_source"
+         split = (self._split or "train+test+validation") if self._split != "val" else "validation"
+
+         if self._download:
+             logger.info("Downloading dataset from HuggingFace Hub")
+             raw_dataset = load_dataset(
+                 dataset_name,
+                 name=config_name,
+                 split=split,
+                 trust_remote_code=True,
+                 download_mode="reuse_dataset_if_exists",
+             )
+             if dataset_path:
+                 raw_dataset.save_to_disk(dataset_path)  # type: ignore
+                 logger.info(f"Dataset saved to: {dataset_path}")
+         else:
+             if not dataset_path or not os.path.exists(dataset_path):
+                 raise ValueError(
+                     "Dataset path not found. Set download=True or provide a valid root path."
+                 )
+
+             logger.info(f"Loading dataset from: {dataset_path}")
+             raw_dataset = load_from_disk(dataset_path)
+
+         return raw_dataset  # type: ignore
+
+     @override
+     def prepare_data(self) -> None:
+         """Downloads and prepares the PubMedQA dataset.
+
+         If `self._root` is None, the dataset is used directly from HuggingFace.
+         Otherwise, it checks if the dataset is already cached in `self._root`.
+         If not cached, it downloads the dataset into `self._root`.
+         """
+         dataset_path = None
+
+         if self._root:
+             dataset_path = self._root
+             os.makedirs(self._root, exist_ok=True)
+
+         try:
+             self.dataset = self._load_dataset(dataset_path)
+             if self._max_samples is not None and len(self.dataset) > self._max_samples:
+                 logger.info(
+                     f"Subsampling dataset from {len(self.dataset)} to {self._max_samples} samples"
+                 )
+                 random.seed(42)
+                 indices = random.sample(range(len(self.dataset)), self._max_samples)
+                 self.dataset = self.dataset.select(indices)
+         except Exception as e:
+             raise RuntimeError(f"Failed to prepare dataset: {e}") from e
+
+     @property
+     @override
+     def classes(self) -> List[str]:
+         return ["no", "yes", "maybe"]
+
+     @property
+     @override
+     def class_to_idx(self) -> Dict[str, int]:
+         return {"no": 0, "yes": 1, "maybe": 2}
+
+     @override
+     def load_text(self, index: int) -> str:
+         if index < 0 or index >= len(self.dataset):
+             raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
+         sample = dict(self.dataset[index])
+         return f"Question: {sample['QUESTION']}\nContext: " + " ".join(sample["CONTEXTS"])
+
+     @override
+     def load_target(self, index: int) -> torch.Tensor:
+         if index < 0 or index >= len(self.dataset):
+             raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
+         return torch.tensor(
+             self.class_to_idx[self.dataset[index]["final_decision"]], dtype=torch.long
+         )
+
+     @override
+     def load_metadata(self, index: int) -> Dict[str, str]:
+         sample = self.dataset[index]
+         return {
+             "year": sample.get("YEAR") or "",
+             "labels": sample.get("LABELS") or "",
+             "meshes": sample.get("MESHES") or "",
+             "long_answer": sample.get("LONG_ANSWER") or "",
+             "reasoning_required": sample.get("reasoning_required_pred") or "",
+             "reasoning_free": sample.get("reasoning_free_pred") or "",
+         }
+
+     @override
+     def __len__(self) -> int:
+         return len(self.dataset)
+
+     def _print_license(self) -> None:
+         """Prints the dataset license."""
+         print(f"Dataset license: {self._license}")
eva/language/data/datasets/language.py
@@ -0,0 +1,13 @@
+ """Language Dataset base class."""
+
+ import abc
+ from typing import Generic, TypeVar
+
+ from eva.core.data.datasets import base
+
+ DataSample = TypeVar("DataSample")
+ """The data sample type."""
+
+
+ class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+     """Base dataset class for text tasks."""
eva/language/models/__init__.py
@@ -0,0 +1,25 @@
+ """Language Models API."""
+
+ from eva.language.models import modules, wrappers
+ from eva.language.models.modules import TextModule
+ from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel
+
+ try:
+     from eva.language.models.wrappers import VLLMTextModel
+
+     __all__ = [
+         "modules",
+         "wrappers",
+         "TextModule",
+         "HuggingFaceTextModel",
+         "LiteLLMTextModel",
+         "VLLMTextModel",
+     ]
+ except ImportError:
+     __all__ = [
+         "modules",
+         "wrappers",
+         "TextModule",
+         "HuggingFaceTextModel",
+         "LiteLLMTextModel",
+     ]
eva/language/models/modules/__init__.py
@@ -0,0 +1,5 @@
+ """Language Networks API."""
+
+ from eva.language.models.modules.text import TextModule
+
+ __all__ = ["TextModule"]
eva/language/models/modules/text.py
@@ -0,0 +1,85 @@
+ """LLM Text Module for Inference."""
+
+ from typing import Any, List
+
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from loguru import logger
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.core.metrics import structs as metrics_lib
+ from eva.core.models.modules import module
+ from eva.core.models.modules.utils import batch_postprocess
+ from eva.language.models.modules.typings import TEXT_BATCH
+
+
+ class TextModule(module.ModelModule):
+     """Text-based LLM module for inference.
+
+     Uses LLM wrappers for text generation and supports evaluation using
+     configurable metrics and post-processing transforms.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         prompt: str,
+         metrics: metrics_lib.MetricsSchema | None = None,
+         postprocess: batch_postprocess.BatchPostProcess | None = None,
+     ) -> None:
+         """Initializes the text inference module.
+
+         Args:
+             model: An LLM wrapper (PyTorch-compatible) for text generation.
+             prompt: The prompt to use for generating text.
+             metrics: Metrics schema for evaluation.
+             postprocess: A helper function to post-process model outputs before evaluation.
+         """
+         super().__init__(metrics=metrics, postprocess=postprocess)
+
+         self.model = model
+         self.prompt = prompt
+
+     @override
+     def forward(self, prompts: List[str], *args: Any, **kwargs: Any) -> List[str]:
+         """Generates text responses for a batch of prompts.
+
+         Args:
+             prompts: List of input texts to generate responses.
+             args: Additional arguments.
+             kwargs: Additional keyword arguments.
+
+         Returns:
+             List of generated responses.
+         """
+         return self.model(prompts)
+
+     @override
+     def validation_step(self, batch: TEXT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         """Validation step that runs batch inference and evaluates metrics.
+
+         Args:
+             batch: An input batch.
+             args: Additional arguments.
+             kwargs: Additional keyword arguments.
+
+         Returns:
+             Dictionary with predictions, ground truth, and evaluation metrics.
+         """
+         return self._batch_step(batch)
+
+     def _batch_step(self, batch: TEXT_BATCH) -> STEP_OUTPUT:
+         """Runs inference on a batch and evaluates model predictions.
+
+         Args:
+             batch: Input batch containing data, targets, and metadata.
+
+         Returns:
+             Dictionary with predictions, ground truth, and evaluation metrics.
+         """
+         data, targets, metadata = batch
+         messages = [str(d) + "\n" + self.prompt for d in data]
+         predictions = self(messages)
+         logger.debug(f"Predictions: {predictions}")
+         logger.debug(f"Targets: {targets}")
+         return {"predictions": predictions, "targets": targets, "metadata": metadata}
eva/language/models/modules/typings.py
@@ -0,0 +1,16 @@
+ """Type annotations for language model modules."""
+
+ from typing import Any, Dict, List, NamedTuple
+
+
+ class TEXT_BATCH(NamedTuple):
+     """Text-based input batch data scheme."""
+
+     data: List[str]
+     """The text data batch."""
+
+     targets: List[str] | None = None
+     """The target text batch."""
+
+     metadata: Dict[str, Any] | None = None
+     """The associated metadata."""
eva/language/models/wrappers/__init__.py
@@ -0,0 +1,11 @@
+ """Language Model Wrappers API."""
+
+ from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
+ from eva.language.models.wrappers.litellm import LiteLLMTextModel
+
+ try:
+     from eva.language.models.wrappers.vllm import VLLMTextModel
+
+     __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
+ except ImportError:
+     __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
eva/language/models/wrappers/huggingface.py
@@ -0,0 +1,69 @@
+ """LLM wrapper for HuggingFace `transformers` models."""
+
+ from typing import Any, Dict, List, Literal
+
+ from transformers.pipelines import pipeline
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import base
+
+
+ class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+     """Wrapper class for loading HuggingFace `transformers` models using pipelines."""
+
+     def __init__(
+         self,
+         model_name_or_path: str,
+         task: Literal["text-generation"] = "text-generation",
+         model_kwargs: Dict[str, Any] | None = None,
+         generation_kwargs: Dict[str, Any] | None = None,
+     ) -> None:
+         """Initializes the model.
+
+         Args:
+             model_name_or_path: The model name or path to load the model from.
+                 This can be a local path or a model name from the `HuggingFace`
+                 model hub.
+             task: The pipeline task. Defaults to "text-generation".
+             model_kwargs: Additional arguments for configuring the pipeline.
+             generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+         """
+         super().__init__()
+
+         self._model_name_or_path = model_name_or_path
+         self._task = task
+         self._model_kwargs = model_kwargs or {}
+         self._generation_kwargs = generation_kwargs or {}
+
+         self.load_model()
+
+     @override
+     def load_model(self) -> None:
+         """Loads the model as a Hugging Face pipeline."""
+         self._pipeline = pipeline(
+             task=self._task,
+             model=self._model_name_or_path,
+             trust_remote_code=True,
+             **self._model_kwargs,
+         )
+
+     @override
+     def model_forward(self, prompts: List[str]) -> List[str]:
+         """Generates text using the pipeline.
+
+         Args:
+             prompts: The input prompts for the model.
+
+         Returns:
+             The generated texts, one per prompt.
+         """
+         outputs = self._pipeline(prompts, return_full_text=False, **self._generation_kwargs)
+         if outputs is None:
+             raise ValueError("Outputs from the model are None.")
+         results = []
+         for output in outputs:
+             if isinstance(output, list):
+                 results.append(output[0]["generated_text"])  # type: ignore
+             else:
+                 results.append(output["generated_text"])  # type: ignore
+         return results
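A quick usage sketch for the wrapper (the checkpoint id and generation settings are illustrative):

    from eva.language.models import HuggingFaceTextModel

    model = HuggingFaceTextModel(
        "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative checkpoint
        generation_kwargs={"max_new_tokens": 32, "do_sample": False},
    )
    answers = model(["Question: Is water wet?\nAnswer yes or no."])
    print(answers[0])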