kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- eva/core/data/dataloaders/__init__.py +2 -1
- eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/core/data/dataloaders/collate_fn/collate.py +24 -0
- eva/core/data/dataloaders/dataloader.py +4 -0
- eva/core/interface/interface.py +34 -1
- eva/core/metrics/defaults/classification/multiclass.py +45 -35
- eva/core/models/modules/__init__.py +2 -1
- eva/core/models/modules/scheduler.py +51 -0
- eva/core/models/transforms/extract_cls_features.py +1 -1
- eva/core/models/transforms/extract_patch_features.py +1 -1
- eva/core/models/wrappers/base.py +17 -14
- eva/core/models/wrappers/from_function.py +5 -4
- eva/core/models/wrappers/from_torchhub.py +5 -6
- eva/core/models/wrappers/huggingface.py +8 -5
- eva/core/models/wrappers/onnx.py +4 -4
- eva/core/trainers/functional.py +40 -43
- eva/core/utils/factory.py +66 -0
- eva/core/utils/registry.py +42 -0
- eva/core/utils/requirements.py +26 -0
- eva/language/__init__.py +13 -0
- eva/language/data/__init__.py +5 -0
- eva/language/data/datasets/__init__.py +9 -0
- eva/language/data/datasets/classification/__init__.py +7 -0
- eva/language/data/datasets/classification/base.py +63 -0
- eva/language/data/datasets/classification/pubmedqa.py +149 -0
- eva/language/data/datasets/language.py +13 -0
- eva/language/models/__init__.py +25 -0
- eva/language/models/modules/__init__.py +5 -0
- eva/language/models/modules/text.py +85 -0
- eva/language/models/modules/typings.py +16 -0
- eva/language/models/wrappers/__init__.py +11 -0
- eva/language/models/wrappers/huggingface.py +69 -0
- eva/language/models/wrappers/litellm.py +77 -0
- eva/language/models/wrappers/vllm.py +149 -0
- eva/language/utils/__init__.py +5 -0
- eva/language/utils/str_to_int_tensor.py +95 -0
- eva/vision/data/dataloaders/__init__.py +2 -1
- eva/vision/data/dataloaders/worker_init.py +35 -0
- eva/vision/data/datasets/__init__.py +5 -5
- eva/vision/data/datasets/segmentation/__init__.py +4 -4
- eva/vision/data/datasets/segmentation/btcv.py +3 -0
- eva/vision/data/datasets/segmentation/consep.py +5 -4
- eva/vision/data/datasets/segmentation/lits17.py +231 -0
- eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
- eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
- eva/vision/data/transforms/__init__.py +11 -2
- eva/vision/data/transforms/base/__init__.py +5 -0
- eva/vision/data/transforms/base/monai.py +27 -0
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/squeeze.py +24 -0
- eva/vision/data/transforms/croppad/__init__.py +4 -0
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
- eva/vision/models/modules/semantic_segmentation.py +18 -7
- eva/vision/models/networks/backbones/__init__.py +2 -3
- eva/vision/models/networks/backbones/_utils.py +1 -1
- eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
- eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
- eva/vision/models/networks/backbones/pathology/histai.py +3 -3
- eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
- eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
- eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
- eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
- eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
- eva/vision/models/networks/backbones/pathology/paige.py +3 -3
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
- eva/vision/models/networks/backbones/radiology/voco.py +5 -5
- eva/vision/models/networks/backbones/registry.py +2 -44
- eva/vision/models/networks/backbones/timm/backbones.py +2 -2
- eva/vision/models/networks/backbones/universal/__init__.py +8 -1
- eva/vision/models/networks/backbones/universal/vit.py +53 -3
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
- eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
- eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
- eva/vision/models/wrappers/from_registry.py +14 -9
- eva/vision/models/wrappers/from_timm.py +6 -5
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/segmentation/lits.py +0 -199
- eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
- /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
eva/core/utils/factory.py
ADDED
@@ -0,0 +1,66 @@
+"""Factory classes."""
+
+import inspect
+from typing import Any, Dict, Generic, Type, TypeVar
+
+from torch import nn
+from typing_extensions import override
+
+from eva.core.utils.registry import Registry, RegistryItem
+
+T = TypeVar("T")
+
+
+class Factory(Generic[T]):
+    """A base factory class for instantiating registry items of a specific type."""
+
+    def __new__(cls, registry: Registry, name: str, init_args: dict, expected_type: Type[T]) -> T:
+        """Creates the appropriate instance based on registry entry.
+
+        Args:
+            registry: The registry containing the items to instantiate.
+            name: Name of the registry item to instantiate.
+            init_args: The arguments to pass to the constructor of the registry item.
+            expected_type: The expected type of the instantiated object.
+        """
+        if name not in registry.entries():
+            raise ValueError(
+                f"Invalid name: {name}. Please choose one "
+                f"of the following: {registry.entries()}"
+            )
+
+        registry_item = registry.get(name)
+        filtered_kwargs = _filter_kwargs(registry_item, init_args)
+        instance = registry_item(**filtered_kwargs)
+
+        if not isinstance(instance, expected_type):
+            raise TypeError(f"Expected an instance of {expected_type}, but got {type(instance)}.")
+        return instance
+
+
+class ModuleFactory(Factory[nn.Module]):
+    """Factory class for instantiating nn.Module instances from a registry."""
+
+    @override
+    def __new__(cls, registry: Registry, name: str, init_args: dict) -> nn.Module:
+        return super().__new__(cls, registry, name, init_args, nn.Module)
+
+
+def _filter_kwargs(registry_item: RegistryItem, kwargs: dict) -> Dict[str, Any]:
+    """Filters the given keyword arguments to match the signature of a given class or method.
+
+    Args:
+        registry_item: The class or method from the registry whose
+            signature should be used for filtering.
+        kwargs: The keyword arguments to filter.
+
+    Returns:
+        A dictionary containing only the valid keyword arguments that match
+        the callable's parameters.
+    """
+    if inspect.isclass(registry_item):
+        signature = inspect.signature(registry_item.__init__)
+    else:
+        signature = inspect.signature(registry_item)
+
+    return {k: v for k, v in kwargs.items() if k in signature.parameters}
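Usage note: together with the Registry in the next hunk, this gives config-driven instantiation where surplus keys in init_args are silently dropped by _filter_kwargs instead of raising a TypeError. A minimal sketch of that composition; the BACKBONES registry instance and the LinearHead class are hypothetical, not part of this diff:

    from torch import nn

    from eva.core.utils.factory import ModuleFactory
    from eva.core.utils.registry import Registry

    BACKBONES = Registry()  # hypothetical registry instance

    @BACKBONES.register("linear")
    class LinearHead(nn.Module):
        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__()
            self.fc = nn.Linear(in_features, out_features)

    # "momentum" is not a LinearHead.__init__ parameter, so _filter_kwargs drops it.
    head = ModuleFactory(
        BACKBONES, "linear", {"in_features": 8, "out_features": 2, "momentum": 0.9}
    )
    assert isinstance(head, nn.Module)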
eva/core/utils/registry.py
ADDED
@@ -0,0 +1,42 @@
+"""Registry for classes and methods."""
+
+from typing import Any, Callable, Dict, List, Type, Union
+
+RegistryItem = Union[Type[Any], Callable[..., Any]]
+
+
+class Registry:
+    """A registry to store and access classes and methods by a unique key."""
+
+    def __init__(self) -> None:
+        """Initializes the registry class."""
+        self._registry: Dict[str, RegistryItem] = {}
+
+    def register(self, key: str, /) -> Callable[[RegistryItem], RegistryItem]:
+        """A decorator to register a class or method with a unique key.
+
+        Args:
+            key: The key to register the class or method under.
+
+        Returns:
+            A decorator that registers the class or method in the registry.
+        """
+
+        def wrapper(obj: RegistryItem) -> RegistryItem:
+            if key in self.entries():
+                raise ValueError(f"Entry {key} is already registered.")
+
+            self._registry[key] = obj
+            return obj
+
+        return wrapper
+
+    def get(self, name: str) -> RegistryItem:
+        """Gets the class or method from the registry."""
+        if name not in self._registry:
+            raise ValueError(f"Item {name} not found in the registry.")
+        return self._registry[name]
+
+    def entries(self) -> List[str]:
+        """List all items in the registry."""
+        return list(self._registry.keys())
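Usage note: register works for plain callables as well as classes, and re-using a key raises at registration time rather than at lookup. A short sketch using only the API shown above (the registry name and function are hypothetical):

    from eva.core.utils.registry import Registry

    transforms = Registry()

    @transforms.register("upper")
    def to_upper(text: str) -> str:
        return text.upper()

    assert transforms.entries() == ["upper"]
    assert transforms.get("upper")("abc") == "ABC"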
eva/core/utils/requirements.py
ADDED
@@ -0,0 +1,26 @@
+"""Utility functions related to package requirements."""
+
+import importlib
+from typing import Dict
+
+from packaging import version
+
+
+def check_dependencies(requirements: Dict[str, str]) -> None:
+    """Check installed package versions against requirements dict.
+
+    Args:
+        requirements: A dictionary where keys are package names and
+            values are minimum required versions.
+
+    Raises:
+        ImportError: If any package does not meet the minimum required version.
+    """
+    for package, min_version in requirements.items():
+        module = importlib.import_module(package)
+        actual = getattr(module, "__version__", None)
+        if actual and not (version.parse(actual) >= version.parse(min_version)):
+            raise ImportError(
+                f"Package '{package}' version {actual} does not meet "
+                f"the minimum required version {min_version}."
+            )
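Usage note: the check imports each key as a module name and compares its __version__ attribute, so a missing package raises ModuleNotFoundError from import_module, while a module that exposes no __version__ passes silently (actual is None). A minimal sketch:

    from eva.core.utils.requirements import check_dependencies

    # Raises ImportError if the installed torch is older than 2.0.0.
    check_dependencies({"torch": "2.0.0"})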
eva/language/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""eva language API."""
+
+try:
+    from eva.language.data import datasets
+except ImportError as e:
+    msg = (
+        "eva language requirements are not installed.\n\n"
+        "Please pip install as follows:\n"
+        ' python -m pip install "kaiko-eva[language]" --upgrade'
+    )
+    raise ImportError(str(e) + "\n\n" + msg) from e
+
+__all__ = ["datasets"]
eva/language/data/datasets/classification/base.py
ADDED
@@ -0,0 +1,63 @@
+"""Base for text classification datasets."""
+
+import abc
+from typing import Any, Dict, List, Tuple
+
+import torch
+from typing_extensions import override
+
+from eva.language.data.datasets.language import LanguageDataset
+
+
+class TextClassification(LanguageDataset[Tuple[str, torch.Tensor, Dict[str, Any]]], abc.ABC):
+    """Text classification abstract dataset."""
+
+    def __init__(self) -> None:
+        """Initializes the text classification dataset."""
+        super().__init__()
+
+    @property
+    def classes(self) -> List[str] | None:
+        """Returns list of class names."""
+
+    @property
+    def class_to_idx(self) -> Dict[str, int] | None:
+        """Returns class name to index mapping."""
+
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The sample metadata.
+        """
+
+    @abc.abstractmethod
+    def load_text(self, index: int) -> str:
+        """Returns the text content.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The text content.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_target(self, index: int) -> torch.Tensor:
+        """Returns the target label.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The target label.
+        """
+        raise NotImplementedError
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
+        return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
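Usage note: concrete datasets only have to implement load_text and load_target; __getitem__ assembles the (text, target, metadata) triple and substitutes {} when load_metadata returns None. A toy in-memory subclass as a sketch, assuming the MapDataset base adds no further abstract hooks (the class and its data are hypothetical):

    import torch

    from eva.language.data.datasets.classification.base import TextClassification

    class ToyTexts(TextClassification):
        _samples = [("good movie", 1), ("bad movie", 0)]  # hypothetical data

        def load_text(self, index: int) -> str:
            return self._samples[index][0]

        def load_target(self, index: int) -> torch.Tensor:
            return torch.tensor(self._samples[index][1], dtype=torch.long)

        def __len__(self) -> int:
            return len(self._samples)

    text, target, metadata = ToyTexts()[0]  # metadata falls back to {}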
eva/language/data/datasets/classification/pubmedqa.py
ADDED
@@ -0,0 +1,149 @@
+"""PubMedQA dataset class."""
+
+import os
+import random
+from typing import Dict, List, Literal
+
+import torch
+from datasets import Dataset, load_dataset, load_from_disk
+from loguru import logger
+from typing_extensions import override
+
+from eva.language.data.datasets.classification import base
+
+
+class PubMedQA(base.TextClassification):
+    """Dataset class for PubMedQA question answering task."""
+
+    _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str | None = None,
+        split: Literal["train", "val", "test"] | None = None,
+        download: bool = False,
+        max_samples: int | None = None,
+    ) -> None:
+        """Initialize the PubMedQA dataset.
+
+        Args:
+            root: Directory to cache the dataset. If None, no local caching is used.
+            split: Valid splits among ["train", "val", "test"].
+                If None, it will use "train+test+validation".
+            download: Whether to download the dataset if not found locally. Default is False.
+            max_samples: Maximum number of samples to use. If None, use all samples.
+        """
+        super().__init__()
+
+        self._root = root
+        self._split = split
+        self._download = download
+        self._max_samples = max_samples
+
+    def _load_dataset(self, dataset_path: str | None) -> Dataset:
+        """Loads the PubMedQA dataset from the local cache or downloads it.
+
+        Args:
+            dataset_path: The path to the local cache (may be None).
+
+        Returns:
+            The loaded dataset object.
+        """
+        dataset_name = "bigbio/pubmed_qa"
+        config_name = "pubmed_qa_labeled_fold0_source"
+        split = (self._split or "train+test+validation") if self._split != "val" else "validation"
+
+        if self._download:
+            logger.info("Downloading dataset from HuggingFace Hub")
+            raw_dataset = load_dataset(
+                dataset_name,
+                name=config_name,
+                split=split,
+                trust_remote_code=True,
+                download_mode="reuse_dataset_if_exists",
+            )
+            if dataset_path:
+                raw_dataset.save_to_disk(dataset_path)  # type: ignore
+                logger.info(f"Dataset saved to: {dataset_path}")
+        else:
+            if not dataset_path or not os.path.exists(dataset_path):
+                raise ValueError(
+                    "Dataset path not found. Set download=True or provide a valid root path."
+                )
+
+            logger.info(f"Loading dataset from: {dataset_path}")
+            raw_dataset = load_from_disk(dataset_path)
+
+        return raw_dataset  # type: ignore
+
+    @override
+    def prepare_data(self) -> None:
+        """Downloads and prepares the PubMedQA dataset.
+
+        If `self._root` is None, the dataset is used directly from HuggingFace.
+        Otherwise, it checks if the dataset is already cached in `self._root`.
+        If not cached, it downloads the dataset into `self._root`.
+        """
+        dataset_path = None
+
+        if self._root:
+            dataset_path = self._root
+            os.makedirs(self._root, exist_ok=True)
+
+        try:
+            self.dataset = self._load_dataset(dataset_path)
+            if self._max_samples is not None and len(self.dataset) > self._max_samples:
+                logger.info(
+                    f"Subsampling dataset from {len(self.dataset)} to {self._max_samples} samples"
+                )
+                random.seed(42)
+                indices = random.sample(range(len(self.dataset)), self._max_samples)
+                self.dataset = self.dataset.select(indices)
+        except Exception as e:
+            raise RuntimeError(f"Failed to prepare dataset: {e}") from e
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["no", "yes", "maybe"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"no": 0, "yes": 1, "maybe": 2}
+
+    @override
+    def load_text(self, index: int) -> str:
+        if index < 0 or index >= len(self.dataset):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
+        sample = dict(self.dataset[index])
+        return f"Question: {sample['QUESTION']}\nContext: " + " ".join(sample["CONTEXTS"])
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        if index < 0 or index >= len(self.dataset):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
+        return torch.tensor(
+            self.class_to_idx[self.dataset[index]["final_decision"]], dtype=torch.long
+        )
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, str]:
+        sample = self.dataset[index]
+        return {
+            "year": sample.get("YEAR") or "",
+            "labels": sample.get("LABELS") or "",
+            "meshes": sample.get("MESHES") or "",
+            "long_answer": sample.get("LONG_ANSWER") or "",
+            "reasoning_required": sample.get("reasoning_required_pred") or "",
+            "reasoning_free": sample.get("reasoning_free_pred") or "",
+        }
+
+    @override
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
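Usage note: a minimal end-to-end sketch based on the code above; download=True pulls bigbio/pubmed_qa (config pubmed_qa_labeled_fold0_source) from the HuggingFace Hub and caches it under root, and prepare_data must run before indexing since it is what sets self.dataset:

    from eva.language.data.datasets.classification.pubmedqa import PubMedQA

    dataset = PubMedQA(root="./data/pubmedqa", split="test", download=True)
    dataset.prepare_data()

    text, target, metadata = dataset[0]
    print(dataset.classes[int(target)])  # "no", "yes" or "maybe"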
eva/language/data/datasets/language.py
ADDED
@@ -0,0 +1,13 @@
+"""Language Dataset base class."""
+
+import abc
+from typing import Generic, TypeVar
+
+from eva.core.data.datasets import base
+
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+    """Base dataset class for text tasks."""
eva/language/models/__init__.py
ADDED
@@ -0,0 +1,25 @@
+"""Language Models API."""
+
+from eva.language.models import modules, wrappers
+from eva.language.models.modules import TextModule
+from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel
+
+try:
+    from eva.language.models.wrappers import VLLMTextModel
+
+    __all__ = [
+        "modules",
+        "wrappers",
+        "TextModule",
+        "HuggingFaceTextModel",
+        "LiteLLMTextModel",
+        "VLLMTextModel",
+    ]
+except ImportError:
+    __all__ = [
+        "modules",
+        "wrappers",
+        "TextModule",
+        "HuggingFaceTextModel",
+        "LiteLLMTextModel",
+    ]
eva/language/models/modules/text.py
ADDED
@@ -0,0 +1,85 @@
+"""LLM Text Module for Inference."""
+
+from typing import Any, List
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from loguru import logger
+from torch import nn
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules import module
+from eva.core.models.modules.utils import batch_postprocess
+from eva.language.models.modules.typings import TEXT_BATCH
+
+
+class TextModule(module.ModelModule):
+    """Text-based LLM module for inference.
+
+    Uses LLM wrappers for text generation and supports evaluation using
+    configurable metrics and post-processing transforms.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        prompt: str,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the text inference module.
+
+        Args:
+            model: An LLM wrapper (PyTorch-compatible) for text generation.
+            prompt: The prompt to use for generating text.
+            metrics: Metrics schema for evaluation.
+            postprocess: A helper function to post-process model outputs before evaluation.
+        """
+        super().__init__(metrics=metrics, postprocess=postprocess)
+
+        self.model = model
+        self.prompt = prompt
+
+    @override
+    def forward(self, prompts: List[str], *args: Any, **kwargs: Any) -> List[str]:
+        """Generates text responses for a batch of prompts.
+
+        Args:
+            prompts: List of input texts to generate responses.
+            args: Additional arguments.
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            List of generated responses.
+        """
+        return self.model(prompts)
+
+    @override
+    def validation_step(self, batch: TEXT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        """Validation step that runs batch inference and evaluates metrics.
+
+        Args:
+            batch: An input batch.
+            args: Additional arguments.
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            Dictionary with predictions, ground truth, and evaluation metrics.
+        """
+        return self._batch_step(batch)
+
+    def _batch_step(self, batch: TEXT_BATCH) -> STEP_OUTPUT:
+        """Runs inference on a batch and evaluates model predictions.
+
+        Args:
+            batch: Input batch containing data, targets, and metadata.
+
+        Returns:
+            Dictionary with predictions, ground truth, and evaluation metrics.
+        """
+        data, targets, metadata = batch
+        messages = [str(d) + "\n" + self.prompt for d in data]
+        predictions = self(messages)
+        logger.debug(f"Predictions: {predictions}")
+        logger.debug(f"Targets: {targets}")
+        return {"predictions": predictions, "targets": targets, "metadata": metadata}
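Usage note: _batch_step appends the configured prompt after each input text, so the prompt typically carries the task instruction. A sketch wiring the module to the HuggingFace wrapper from the last hunk below; the checkpoint name is only an example, not one this diff mentions:

    from eva.language.models import HuggingFaceTextModel, TextModule

    module = TextModule(
        model=HuggingFaceTextModel("Qwen/Qwen2.5-0.5B-Instruct"),  # example checkpoint
        prompt="Answer yes, no or maybe.",
    )
    responses = module(["Question: ...\nContext: ..."])  # forward -> model(prompts)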
eva/language/models/modules/typings.py
ADDED
@@ -0,0 +1,16 @@
+"""Type annotations for language model modules."""
+
+from typing import Any, Dict, List, NamedTuple
+
+
+class TEXT_BATCH(NamedTuple):
+    """Text-based input batch data scheme."""
+
+    data: List[str]
+    """The text data batch."""
+
+    targets: List[str] | None = None
+    """The target text batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
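Usage note: as a NamedTuple, TEXT_BATCH unpacks positionally, which is what `data, targets, metadata = batch` in TextModule._batch_step relies on:

    from eva.language.models.modules.typings import TEXT_BATCH

    batch = TEXT_BATCH(data=["Question: ...", "Question: ..."], targets=["yes", "no"])
    data, targets, metadata = batch  # metadata defaults to None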
eva/language/models/wrappers/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""Language Model Wrappers API."""
+
+from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
+from eva.language.models.wrappers.litellm import LiteLLMTextModel
+
+try:
+    from eva.language.models.wrappers.vllm import VLLMTextModel
+
+    __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
+except ImportError:
+    __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
eva/language/models/wrappers/huggingface.py
ADDED
@@ -0,0 +1,69 @@
+"""LLM wrapper for HuggingFace `transformers` models."""
+
+from typing import Any, Dict, List, Literal
+
+from transformers.pipelines import pipeline
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+
+
+class HuggingFaceTextModel(base.BaseModel[List[str], List[str]]):
+    """Wrapper class for loading HuggingFace `transformers` models using pipelines."""
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        task: Literal["text-generation"] = "text-generation",
+        model_kwargs: Dict[str, Any] | None = None,
+        generation_kwargs: Dict[str, Any] | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name_or_path: The model name or path to load the model from.
+                This can be a local path or a model name from the `HuggingFace`
+                model hub.
+            task: The pipeline task. Defaults to "text-generation".
+            model_kwargs: Additional arguments for configuring the pipeline.
+            generation_kwargs: Additional generation parameters (temperature, max_length, etc.).
+        """
+        super().__init__()
+
+        self._model_name_or_path = model_name_or_path
+        self._task = task
+        self._model_kwargs = model_kwargs or {}
+        self._generation_kwargs = generation_kwargs or {}
+
+        self.load_model()
+
+    @override
+    def load_model(self) -> None:
+        """Loads the model as a Hugging Face pipeline."""
+        self._pipeline = pipeline(
+            task=self._task,
+            model=self._model_name_or_path,
+            trust_remote_code=True,
+            **self._model_kwargs,
+        )
+
+    @override
+    def model_forward(self, prompts: List[str]) -> List[str]:
+        """Generates text using the pipeline.
+
+        Args:
+            prompts: The input prompts for the model.
+
+        Returns:
+            The generated text as a string.
+        """
+        outputs = self._pipeline(prompts, return_full_text=False, **self._generation_kwargs)
+        if outputs is None:
+            raise ValueError("Outputs from the model are None.")
+        results = []
+        for output in outputs:
+            if isinstance(output, list):
+                results.append(output[0]["generated_text"])  # type: ignore
+            else:
+                results.append(output["generated_text"])  # type: ignore
+        return results
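Usage note: the wrapper is also usable standalone; generation_kwargs are forwarded to each pipeline call, and return_full_text=False strips the prompt from the generated output. A sketch with an example checkpoint (not one this diff names), assuming BaseModel routes calls to model_forward as TextModule's `self.model(prompts)` suggests:

    from eva.language.models.wrappers.huggingface import HuggingFaceTextModel

    model = HuggingFaceTextModel(
        "Qwen/Qwen2.5-0.5B-Instruct",  # example checkpoint
        generation_kwargs={"max_new_tokens": 16, "do_sample": False},
    )
    answers = model(["What is the capital of France?"])  # -> list of generated strings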