kaiko-eva 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +95 -38
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -10,11 +10,20 @@ from loguru import logger
|
|
|
10
10
|
from typing_extensions import override
|
|
11
11
|
|
|
12
12
|
from eva.language.data.datasets.classification import base
|
|
13
|
+
from eva.language.data.messages import MessageSeries, UserMessage
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
class PubMedQA(base.TextClassification):
|
|
16
17
|
"""Dataset class for PubMedQA question answering task."""
|
|
17
18
|
|
|
19
|
+
_expected_dataset_lengths: Dict[str | None, int] = {
|
|
20
|
+
"train": 450,
|
|
21
|
+
"val": 50,
|
|
22
|
+
"test": 500,
|
|
23
|
+
None: 500,
|
|
24
|
+
}
|
|
25
|
+
"""Expected dataset lengths for the splits and complete dataset."""
|
|
26
|
+
|
|
18
27
|
_license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
|
|
19
28
|
"""Dataset license."""
|
|
20
29
|
|
|
@@ -52,7 +61,14 @@ class PubMedQA(base.TextClassification):
|
|
|
52
61
|
"""
|
|
53
62
|
dataset_name = "bigbio/pubmed_qa"
|
|
54
63
|
config_name = "pubmed_qa_labeled_fold0_source"
|
|
55
|
-
|
|
64
|
+
|
|
65
|
+
match self._split:
|
|
66
|
+
case "val":
|
|
67
|
+
split = "validation"
|
|
68
|
+
case None:
|
|
69
|
+
split = "train+test+validation"
|
|
70
|
+
case _:
|
|
71
|
+
split = self._split
|
|
56
72
|
|
|
57
73
|
if self._download:
|
|
58
74
|
logger.info("Downloading dataset from HuggingFace Hub")
|
|
@@ -88,7 +104,7 @@ class PubMedQA(base.TextClassification):
|
|
|
88
104
|
dataset_path = None
|
|
89
105
|
|
|
90
106
|
if self._root:
|
|
91
|
-
dataset_path = self._root
|
|
107
|
+
dataset_path = os.path.join(self._root, self._split) if self._split else self._root
|
|
92
108
|
os.makedirs(self._root, exist_ok=True)
|
|
93
109
|
|
|
94
110
|
try:
|
|
@@ -103,6 +119,15 @@ class PubMedQA(base.TextClassification):
|
|
|
103
119
|
except Exception as e:
|
|
104
120
|
raise RuntimeError(f"Failed to prepare dataset: {e}") from e
|
|
105
121
|
|
|
122
|
+
@override
def validate(self) -> None:
    """Checks that the dataset holds the expected number of samples.

    The expected size is looked up in ``_expected_dataset_lengths`` using
    the configured split as key.

    Raises:
        ValueError: If the actual dataset length differs from the expected one.
    """
    actual_length = len(self)
    expected_length = self._expected_dataset_lengths[self._split]
    if actual_length == expected_length:
        return

    raise ValueError(
        f"Dataset length mismatch for split '{self._split}': "
        f"expected {expected_length}, "
        f"but got {actual_length}"
    )
|
|
130
|
+
|
|
106
131
|
@property
|
|
107
132
|
@override
|
|
108
133
|
def classes(self) -> List[str]:
|
|
@@ -114,11 +139,18 @@ class PubMedQA(base.TextClassification):
|
|
|
114
139
|
return {"no": 0, "yes": 1, "maybe": 2}
|
|
115
140
|
|
|
116
141
|
@override
def load_text(self, index: int) -> MessageSeries:
    """Builds the prompt messages for the sample at the given index.

    Args:
        index: The index of the data sample.

    Returns:
        A single-element message series holding one user message with the
        question, the space-joined contexts and the answering instruction.

    Raises:
        IndexError: If ``index`` lies outside of the dataset bounds.
    """
    dataset_size = len(self.dataset)
    if not 0 <= index < dataset_size:
        raise IndexError(f"Index {index} out of range for dataset of size {dataset_size}")

    sample = dict(self.dataset[index])
    question = sample["QUESTION"]
    contexts = " ".join(sample["CONTEXTS"])
    prompt = (
        f"Question: {question}\nContext: {contexts}"
        "\nInstruction: Carefully read the question and the provided context. "
        "Answer with one word: 'yes', 'no', or 'maybe'. Answer: "
    )
    return [UserMessage(content=prompt)]
|
|
122
154
|
|
|
123
155
|
@override
|
|
124
156
|
def load_target(self, index: int) -> torch.Tensor:
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Dataset class for loading pre-generated text predictions."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, Generic, Literal
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
from eva.language.data.datasets.base import LanguageDataset
|
|
11
|
+
from eva.language.data.datasets.schemas import TransformsSchema
|
|
12
|
+
from eva.language.data.datasets.typings import PredictionSample, TargetType
|
|
13
|
+
from eva.language.data.messages import MessageSeries, UserMessage
|
|
14
|
+
from eva.language.utils.text import messages as message_utils
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TextPredictionDataset(
|
|
18
|
+
LanguageDataset[PredictionSample[TargetType]], abc.ABC, Generic[TargetType]
|
|
19
|
+
):
|
|
20
|
+
"""Dataset class for loading pre-generated text predictions."""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
path: str,
|
|
25
|
+
prediction_column: str = "prediction",
|
|
26
|
+
target_column: str = "target",
|
|
27
|
+
text_column: str | None = None,
|
|
28
|
+
metadata_columns: list[str] | None = None,
|
|
29
|
+
split: Literal["train", "val", "test"] | None = None,
|
|
30
|
+
transforms: TransformsSchema | None = None,
|
|
31
|
+
):
|
|
32
|
+
"""Initialize the dataset.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
path: The path to the manifest file holding the predictions & targets.
|
|
36
|
+
prediction_column: The name of the prediction column.
|
|
37
|
+
target_column: The name of the label column.
|
|
38
|
+
text_column: The name of the column with the text inputs that were used
|
|
39
|
+
to generate the predictions. If the text column contains chat message
|
|
40
|
+
json format ([{"role": ..., "content": ...}]), it will be deserialized into
|
|
41
|
+
a list of Message objects. Otherwise, the content is interpreted as a
|
|
42
|
+
single user message.
|
|
43
|
+
metadata_columns: List of column names to include in metadata.
|
|
44
|
+
split: The dataset split to use (train, val, test). If not specified,
|
|
45
|
+
the entire dataset will be used.
|
|
46
|
+
transforms: The transforms to apply to the text and target when
|
|
47
|
+
loading the samples.
|
|
48
|
+
"""
|
|
49
|
+
super().__init__()
|
|
50
|
+
|
|
51
|
+
self.path = path
|
|
52
|
+
self.prediction_column = prediction_column
|
|
53
|
+
self.target_column = target_column
|
|
54
|
+
self.text_column = text_column
|
|
55
|
+
self.metadata_columns = metadata_columns
|
|
56
|
+
self.split = split
|
|
57
|
+
self.transforms = transforms
|
|
58
|
+
|
|
59
|
+
self._data: pd.DataFrame
|
|
60
|
+
|
|
61
|
+
@override
|
|
62
|
+
def __len__(self) -> int:
|
|
63
|
+
return len(self._data)
|
|
64
|
+
|
|
65
|
+
@override
|
|
66
|
+
def __getitem__(self, index: int) -> PredictionSample[TargetType]:
|
|
67
|
+
item = PredictionSample(
|
|
68
|
+
prediction=self.load_prediction(index),
|
|
69
|
+
target=self.load_target(index),
|
|
70
|
+
text=self.load_text(index),
|
|
71
|
+
metadata=self.load_metadata(index) or {},
|
|
72
|
+
)
|
|
73
|
+
return self._apply_transforms(item)
|
|
74
|
+
|
|
75
|
+
@override
|
|
76
|
+
def configure(self) -> None:
|
|
77
|
+
extension = Path(self.path).suffix
|
|
78
|
+
|
|
79
|
+
match extension:
|
|
80
|
+
case ".jsonl":
|
|
81
|
+
self._data = pd.read_json(self.path, lines=True)
|
|
82
|
+
case ".csv":
|
|
83
|
+
self._data = pd.read_csv(self.path)
|
|
84
|
+
case ".parquet":
|
|
85
|
+
self._data = pd.read_parquet(self.path)
|
|
86
|
+
case _:
|
|
87
|
+
raise ValueError(f"Unsupported file extension: {extension}")
|
|
88
|
+
|
|
89
|
+
if self.split is not None:
|
|
90
|
+
self._data = self._data[self._data["split"] == self.split].reset_index(drop=True) # type: ignore
|
|
91
|
+
|
|
92
|
+
@override
|
|
93
|
+
def validate(self) -> None:
|
|
94
|
+
if self.prediction_column not in self._data.columns:
|
|
95
|
+
raise ValueError(f"Label column '{self.prediction_column}' not found.")
|
|
96
|
+
if self.target_column not in self._data.columns:
|
|
97
|
+
raise ValueError(f"Label column '{self.target_column}' not found.")
|
|
98
|
+
if self.metadata_columns:
|
|
99
|
+
missing_columns = set(self.metadata_columns) - set(self._data.columns)
|
|
100
|
+
if missing_columns:
|
|
101
|
+
raise ValueError(f"Metadata columns {missing_columns} not found.")
|
|
102
|
+
|
|
103
|
+
def load_prediction(self, index: int) -> TargetType:
|
|
104
|
+
"""Returns the prediction for the given index."""
|
|
105
|
+
return self._data.iloc[index][self.prediction_column]
|
|
106
|
+
|
|
107
|
+
def load_target(self, index: int) -> TargetType:
|
|
108
|
+
"""Returns the target for the given index."""
|
|
109
|
+
return self._data.iloc[index][self.target_column]
|
|
110
|
+
|
|
111
|
+
def load_text(self, index: int) -> MessageSeries | None:
|
|
112
|
+
"""Returns the text for the given index."""
|
|
113
|
+
if self.text_column is None:
|
|
114
|
+
return None
|
|
115
|
+
|
|
116
|
+
text = self._data.iloc[index][self.text_column]
|
|
117
|
+
|
|
118
|
+
try:
|
|
119
|
+
return message_utils.deserialize(self._data.iloc[index][self.text_column])
|
|
120
|
+
except Exception:
|
|
121
|
+
return [UserMessage(content=text)]
|
|
122
|
+
|
|
123
|
+
def load_metadata(self, index: int) -> Dict[str, Any] | None:
|
|
124
|
+
"""Returns the metadata for the given index."""
|
|
125
|
+
if self.metadata_columns is None:
|
|
126
|
+
return None
|
|
127
|
+
|
|
128
|
+
row = self._data.iloc[index]
|
|
129
|
+
return {col: row[col] for col in self.metadata_columns}
|
|
130
|
+
|
|
131
|
+
def _apply_transforms(
|
|
132
|
+
self, sample: PredictionSample[TargetType]
|
|
133
|
+
) -> PredictionSample[TargetType]:
|
|
134
|
+
"""Applies the dataset transforms to the prediction and target."""
|
|
135
|
+
if self.transforms:
|
|
136
|
+
text = self.transforms.text(sample.text) if self.transforms.text else sample.text
|
|
137
|
+
prediction = (
|
|
138
|
+
self.transforms.prediction(sample.prediction)
|
|
139
|
+
if self.transforms.prediction
|
|
140
|
+
else sample.prediction
|
|
141
|
+
)
|
|
142
|
+
target = (
|
|
143
|
+
self.transforms.target(sample.target) if self.transforms.target else sample.target
|
|
144
|
+
)
|
|
145
|
+
return PredictionSample(
|
|
146
|
+
prediction=prediction,
|
|
147
|
+
target=target,
|
|
148
|
+
text=text,
|
|
149
|
+
metadata=sample.metadata,
|
|
150
|
+
)
|
|
151
|
+
return sample
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Schema definitions for dataset classes."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass(frozen=True)
|
|
8
|
+
class TransformsSchema:
|
|
9
|
+
"""Schema for dataset transforms."""
|
|
10
|
+
|
|
11
|
+
text: Callable | None = None
|
|
12
|
+
"""Text transformation"""
|
|
13
|
+
|
|
14
|
+
target: Callable | None = None
|
|
15
|
+
"""Target transformation"""
|
|
16
|
+
|
|
17
|
+
prediction: Callable | None = None
|
|
18
|
+
"""Prediction transformation"""
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Base classes for text datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Dict, Generic
|
|
5
|
+
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.data.datasets.base import LanguageDataset
|
|
9
|
+
from eva.language.data.datasets.schemas import TransformsSchema
|
|
10
|
+
from eva.language.data.datasets.typings import TargetType, TextSample
|
|
11
|
+
from eva.language.data.messages import MessageSeries
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TextDataset(LanguageDataset[TextSample[TargetType]], abc.ABC, Generic[TargetType]):
    """Abstract base for datasets that yield text samples with a target."""

    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
        """Initializes the dataset.

        Args:
            *args: Positional arguments forwarded to the base class.
            transforms: Optional transforms applied to the text and target
                of every loaded sample.
            **kwargs: Keyword arguments forwarded to the base class.
        """
        super().__init__(*args, **kwargs)

        self.transforms = transforms

    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        """Returns the metadata of a sample.

        This default implementation provides no metadata.

        Args:
            index: The index of the data sample.

        Returns:
            The sample metadata, ``None`` here.
        """
        return None

    @abc.abstractmethod
    def load_text(self, index: int) -> MessageSeries:
        """Returns the text content of a sample.

        Args:
            index: The index of the data sample.

        Returns:
            The text content as a series of messages.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load_target(self, index: int) -> TargetType:
        """Returns the target label of a sample.

        Args:
            index: The index of the data sample.

        Returns:
            The target label.
        """
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> TextSample[TargetType]:
        text = self.load_text(index)
        target = self.load_target(index)
        # A missing (or empty) metadata result is normalized to an empty dict.
        metadata = self.load_metadata(index) or {}
        return self._apply_transforms(TextSample(text=text, target=target, metadata=metadata))

    def _apply_transforms(self, sample: TextSample[TargetType]) -> TextSample[TargetType]:
        """Applies the dataset transforms to the text and target.

        Args:
            sample: The text sample.

        Returns:
            The transformed sample, or the input sample itself when no
            transforms are configured. Metadata is passed through unchanged.
        """
        transforms = self.transforms
        if not transforms:
            return sample

        text = sample.text
        if transforms.text:
            text = transforms.text(sample.text)

        target = sample.target
        if transforms.target:
            target = transforms.target(sample.target)

        return TextSample(text=text, target=target, metadata=sample.metadata)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Typings for language datasets."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from typing_extensions import NamedTuple
|
|
6
|
+
|
|
7
|
+
from eva.language.data.messages import MessageSeries
|
|
8
|
+
|
|
9
|
+
TargetType = TypeVar("TargetType")
|
|
10
|
+
"""The target data type."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TextSample(NamedTuple, Generic[TargetType]):
    """A single text input together with its target and metadata."""

    text: MessageSeries
    """The conversation message(s) that make up the input."""

    target: TargetType | None
    """The target associated with the input, if any."""

    metadata: dict[str, Any] | None
    """Optional extra information about the sample."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PredictionSample(NamedTuple, Generic[TargetType]):
    """A pre-generated prediction with its target, optional input text and metadata."""

    prediction: TargetType
    """The prediction data."""

    target: TargetType
    """The target the prediction is compared against."""

    text: MessageSeries | None
    """The conversation messages that were used as input, if available."""

    metadata: dict[str, Any] | None
    """Optional extra information about the sample."""
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Types and classes for conversation messages in a multimodal context."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import enum
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Role(str, enum.Enum):
    """The possible authors of a conversation message.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclasses.dataclass
class Message:
    """A single conversation message: its content and the role of its author."""

    content: str
    role: str

    def to_dict(self) -> Dict[str, Any]:
        """Returns the message fields as a dictionary keyed by field name."""
        as_dict = dataclasses.asdict(self)
        return as_dict
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclasses.dataclass
class UserMessage(Message):
    """A message authored by the user; the role defaults to ``Role.USER``."""

    role: str = Role.USER
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclasses.dataclass
class AssistantMessage(Message):
    """A message authored by the assistant; the role defaults to ``Role.ASSISTANT``."""

    role: str = Role.ASSISTANT
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclasses.dataclass
class SystemMessage(Message):
    """A system-level message; the role defaults to ``Role.SYSTEM``."""

    role: str = Role.SYSTEM
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclasses.dataclass
class ModelSystemMessage(SystemMessage):
    """A system message that carries model-specific instructions."""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclasses.dataclass
class TaskSystemMessage(SystemMessage):
    """A system message that carries task-specific instructions."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
MessageSeries = List[Message]
|
|
60
|
+
"""A series of conversation messages, can contain a mix of system, user, and AI messages."""
|
eva/language/models/__init__.py
CHANGED
|
@@ -1,25 +1,29 @@
|
|
|
1
1
|
"""Language Models API."""
|
|
2
2
|
|
|
3
|
-
from eva.language.models import modules, wrappers
|
|
4
|
-
from eva.language.models.modules import
|
|
5
|
-
from eva.language.models.wrappers import
|
|
3
|
+
from eva.language.models import modules, networks, wrappers
|
|
4
|
+
from eva.language.models.modules import LanguageModule, OfflineLanguageModule
|
|
5
|
+
from eva.language.models.wrappers import HuggingFaceModel, LiteLLMModel
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
|
-
from eva.language.models.wrappers import
|
|
8
|
+
from eva.language.models.wrappers import VllmModel
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"modules",
|
|
12
12
|
"wrappers",
|
|
13
|
-
"
|
|
14
|
-
"
|
|
15
|
-
"
|
|
16
|
-
"
|
|
13
|
+
"networks",
|
|
14
|
+
"HuggingFaceModel",
|
|
15
|
+
"LiteLLMModel",
|
|
16
|
+
"VllmModel",
|
|
17
|
+
"LanguageModule",
|
|
18
|
+
"OfflineLanguageModule",
|
|
17
19
|
]
|
|
18
20
|
except ImportError:
|
|
19
21
|
__all__ = [
|
|
20
22
|
"modules",
|
|
21
23
|
"wrappers",
|
|
22
|
-
"
|
|
23
|
-
"
|
|
24
|
-
"
|
|
24
|
+
"networks",
|
|
25
|
+
"HuggingFaceModel",
|
|
26
|
+
"LiteLLMModel",
|
|
27
|
+
"LanguageModule",
|
|
28
|
+
"OfflineLanguageModule",
|
|
25
29
|
]
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Model module for language models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
6
|
+
from torch import nn
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.metrics import structs as metrics_lib
|
|
10
|
+
from eva.core.models.modules import module
|
|
11
|
+
from eva.core.models.modules.utils import batch_postprocess
|
|
12
|
+
from eva.language.models.typings import PredictionBatch, TextBatch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LanguageModule(module.ModelModule):
    """Model module that runs a language model on batches of text."""

    def __init__(
        self,
        model: nn.Module,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the text inference module.

        Args:
            model: The model instance used for the forward pass.
            metrics: The metrics schema used for evaluation.
            postprocess: A helper to post-process the model outputs
                before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

        self.model = model

    @override
    def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> List[str]:
        # The extra positional / keyword arguments are accepted but not used.
        outputs = self.model(batch)
        return outputs

    @override
    def validation_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: TextBatch) -> STEP_OUTPUT:
        """Runs the model on a batch and collects the step outputs.

        Args:
            batch: The batch holding the text, the targets and the metadata.

        Returns:
            A dictionary with the inputs, predictions, targets and metadata.
        """
        text, targets, metadata = TextBatch(*batch)
        step_output = {
            "inputs": text,
            "predictions": self.forward(batch),
            "targets": targets,
            "metadata": metadata,
        }
        return step_output
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class OfflineLanguageModule(module.ModelModule):
    """Model module that evaluates predictions which were generated beforehand."""

    def __init__(
        self,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the text inference module.

        Args:
            metrics: The metrics schema used for evaluation.
            postprocess: A helper to post-process the model outputs
                before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

    @override
    def forward(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> PredictionBatch:
        # No model is involved: the batch already carries the predictions.
        return batch

    @override
    def validation_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: PredictionBatch) -> STEP_OUTPUT:
        """Unpacks a prediction batch into the step outputs.

        Args:
            batch: The batch holding the predictions, targets, text and metadata.

        Returns:
            A dictionary with the inputs, predictions, targets and metadata.
        """
        predictions, targets, text, metadata = PredictionBatch(*batch)
        step_output = {
            "inputs": text,
            "predictions": predictions,
            "targets": targets,
            "metadata": metadata,
        }
        return step_output
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Language networks API."""
|
|
2
|
+
|
|
3
|
+
from eva.language.models.networks.alibaba import Qwen205BInstruct
|
|
4
|
+
from eva.language.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
|
|
5
|
+
from eva.language.models.networks.registry import model_registry
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Claude35Sonnet20240620",
|
|
9
|
+
"Claude37Sonnet20250219",
|
|
10
|
+
"Qwen205BInstruct",
|
|
11
|
+
"model_registry",
|
|
12
|
+
]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Models from Alibaba."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from eva.language.models import wrappers
|
|
6
|
+
from eva.language.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@model_registry.register("alibaba/qwen2-0-5b-instruct")
class Qwen205BInstruct(wrappers.HuggingFaceModel):
    """Qwen2 0.5B Instruct model, registered as ``alibaba/qwen2-0-5b-instruct``."""

    def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the wrapper.
            cache_dir: Optional cache directory passed along in the model kwargs.
        """
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "cache_dir": cache_dir,
        }
        generation_kwargs = {"max_new_tokens": 512}
        super().__init__(
            model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
            model_kwargs=model_kwargs,
            generation_kwargs=generation_kwargs,
            system_prompt=system_prompt,
            chat_mode=True,
        )
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Models from Anthropic."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from eva.language.models import wrappers
|
|
6
|
+
from eva.language.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _Claude(wrappers.LiteLLMModel):
    """Shared base for the Claude models."""

    def __init__(self, model_name: str, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            model_name: The name of the Claude model to use.
            system_prompt: Optional system prompt forwarded to the wrapper.

        Raises:
            ValueError: If the ``ANTHROPIC_API_KEY`` environment variable
                is unset or empty.
        """
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")

        super().__init__(model_name=model_name, system_prompt=system_prompt)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
class Claude35Sonnet20240620(_Claude):
    """Claude 3.5 Sonnet (June 2024) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the base class.
        """
        super().__init__(
            model_name="claude-3-5-sonnet-20240620",
            system_prompt=system_prompt,
        )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
class Claude37Sonnet20250219(_Claude):
    """Claude 3.7 Sonnet (February 2025) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the base class.
        """
        super().__init__(
            model_name="claude-3-7-sonnet-20250219",
            system_prompt=system_prompt,
        )
|