kaiko-eva 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +95 -38
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Message formatting utilities for language models."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
from eva.language.data.messages import (
|
|
8
|
+
AssistantMessage,
|
|
9
|
+
MessageSeries,
|
|
10
|
+
Role,
|
|
11
|
+
SystemMessage,
|
|
12
|
+
UserMessage,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def format_chat_message(message: MessageSeries) -> List[Dict[str, Any]]:
    """Convert a message series into OpenAI-style chat dictionaries.

    Args:
        message: The conversation messages to convert.

    Returns:
        One ``{"role": ..., "content": ...}`` dictionary per message, in the
        same order as the input series.
    """
    formatted: List[Dict[str, Any]] = []
    for entry in message:
        formatted.append({"role": entry.role, "content": entry.content})
    return formatted
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def combine_system_messages(message: MessageSeries, join_char: str = "\n") -> MessageSeries:
    """Combine all system messages of a series into a single system message.

    This is useful when the MessageSeries contains multiple system messages, such
    as `ModelSystemMessage` and `TaskSystemMessage`. Most models / APIs expect a
    single system message, so this function can be used to combine them into one.

    Args:
        message: The message series containing one or multiple messages.
        join_char: The string used to join the system message contents.
            Default is newline.

    Returns:
        The input series itself (same object) if it contains no system messages.
        Otherwise a new series whose first element is the combined system message,
        followed by all non-system messages in their original relative order.
    """
    system_messages = []
    non_system_messages = []
    # Partition in a single pass; relative order inside each group is preserved.
    for item in message:
        if item.role == Role.SYSTEM:
            system_messages.append(item)
        else:
            non_system_messages.append(item)

    if not system_messages:
        return message

    combined = SystemMessage(content=merge_message_contents(system_messages, join_char=join_char))
    return [combined, *non_system_messages]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def merge_message_contents(message: MessageSeries, join_char: str = "\n") -> str:
    """Merge the contents of all messages in a series into a single string.

    Args:
        message: The message series to combine.
        join_char: The string used to join the message contents. Default is newline.

    Returns:
        The message contents in series order, separated by ``join_char``. An
        empty series yields an empty string.
    """
    return join_char.join(item.content for item in message)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def insert_system_message(
    message: MessageSeries, system_message: SystemMessage | None
) -> MessageSeries:
    """Prepend a system message to a message series.

    Args:
        message: The conversation to extend. It is never modified in place.
        system_message: The system message to put in front; when ``None``,
            nothing is inserted.

    Returns:
        ``message`` itself when ``system_message`` is ``None``, otherwise a new
        list starting with ``system_message`` followed by the original messages.
    """
    if system_message is not None:
        return [system_message] + message
    return message
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def batch_insert_system_message(
    messages: List[MessageSeries], system_message: SystemMessage | None
) -> List[MessageSeries]:
    """Prepend the same system message to every message series in a batch.

    Args:
        messages: The batch of message series.
        system_message: The system message to prepend; ``None`` leaves every
            series unchanged.

    Returns:
        A new list with the result of ``insert_system_message`` for each series,
        in batch order.
    """
    return [
        insert_system_message(series, system_message=system_message) for series in messages
    ]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def serialize(messages: MessageSeries) -> str:
    """Encode a MessageSeries as a JSON string.

    Args:
        messages: A list of message objects (MessageSeries).

    Returns:
        A JSON array with one ``{"role": ..., "content": ...}`` object per
        message, e.g. ``[{"role": "user", "content": "Hello"}, ...]``.
    """
    return json.dumps(format_chat_message(messages))
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def deserialize(messages: str) -> MessageSeries:
|
|
92
|
+
"""Convert a json string to a MessageSeries object.
|
|
93
|
+
|
|
94
|
+
Format: [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
|
|
95
|
+
"""
|
|
96
|
+
message_dicts = json.loads(messages)
|
|
97
|
+
|
|
98
|
+
message_series = []
|
|
99
|
+
for message_dict in message_dicts:
|
|
100
|
+
if "role" not in message_dict or "content" not in message_dict:
|
|
101
|
+
raise ValueError("`role` or `content` keys are missing.")
|
|
102
|
+
|
|
103
|
+
match message_dict["role"]:
|
|
104
|
+
case Role.USER:
|
|
105
|
+
message_series.append(UserMessage(**message_dict))
|
|
106
|
+
case Role.ASSISTANT:
|
|
107
|
+
message_series.append(AssistantMessage(**message_dict))
|
|
108
|
+
case Role.SYSTEM:
|
|
109
|
+
message_series.append(SystemMessage(**message_dict))
|
|
110
|
+
case _:
|
|
111
|
+
raise ValueError(f"Unknown role: {message_dict['role']}")
|
|
112
|
+
|
|
113
|
+
return message_series
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Text prediction writer callbacks."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.callbacks import writers
|
|
9
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TextPredictionWriter(writers.TextPredictionWriter):
    """Callback that writes generated text predictions of text-image batches to disk.

    Behaves like the language ``TextPredictionWriter``, but unpacks
    ``TextImageBatch`` tuples and drops the image entry before writing.
    """

    def __init__(
        self,
        output_dir: str,
        model: nn.Module,
        dataloader_idx_map: Dict[int, str] | None = None,
        metadata_keys: List[str] | None = None,
        include_input: bool = True,
        overwrite: bool = False,
        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
    ) -> None:
        """Initializes the writer; every argument is forwarded unchanged to the base class.

        See the base class docstring for the meaning of each argument.
        """
        super().__init__(
            output_dir=output_dir,
            model=model,
            dataloader_idx_map=dataloader_idx_map,
            metadata_keys=metadata_keys,
            include_input=include_input,
            overwrite=overwrite,
            save_format=save_format,
        )

    @override
    def _unpack_batch(self, batch: TextImageBatch) -> Tuple[list, list | None, dict | None]:  # type: ignore
        """Return the (text, target, metadata) entries of a batch, discarding the image."""
        unpacked = TextImageBatch(*batch)
        return unpacked.text, unpacked.target, unpacked.metadata
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Collate functions for text-image data."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from torch.utils.data._utils.collate import default_collate
|
|
6
|
+
|
|
7
|
+
from eva.multimodal.data.datasets.typings import TextImageSample
|
|
8
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def text_image_collate(batch: List[TextImageSample]) -> TextImageBatch:
    """Collate a list of text-image samples into a single batch.

    Texts and images are kept as plain lists. Targets are collated with torch's
    ``default_collate`` unless the first sample has no target.

    Args:
        batch: The samples to collate.

    Returns:
        A ``TextImageBatch`` with:
        - list-valued ``text`` and ``image``;
        - the collated ``target`` (or ``None``);
        - ``metadata``: a dict mapping each key of the first sample's metadata
          to a list with exactly one value per sample, or ``None`` if the first
          sample has no metadata.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("Cannot collate an empty batch.")

    texts, images, targets, _ = zip(*batch, strict=False)

    first_metadata = batch[0].metadata
    metadata = None
    if first_metadata is not None:
        # Emit one entry per sample, using `None` for samples without metadata,
        # so the metadata lists stay index-aligned with texts and targets.
        metadata = {
            key: [sample.metadata[key] if sample.metadata else None for sample in batch]
            for key in first_metadata
        }

    return TextImageBatch(
        text=list(texts),
        image=list(images),
        target=default_collate(targets) if targets[0] is not None else None,
        metadata=metadata,
    )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Multimodal Dataset base class."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from eva.core.data.datasets import base
|
|
7
|
+
|
|
8
|
+
DataSample = TypeVar("DataSample")
"""Type of the samples produced by a multimodal dataset."""


class MultimodalDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
    """Abstract base class for map-style datasets that combine several modalities.

    Generic over ``DataSample``, the sample type of the concrete dataset.
    """
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""PatchCamelyon dataset with text prompts for multimodal classification."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Literal
|
|
4
|
+
|
|
5
|
+
from torchvision import tv_tensors
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.data.messages import MessageSeries, UserMessage
|
|
9
|
+
from eva.multimodal.data.datasets.schemas import TransformsSchema
|
|
10
|
+
from eva.multimodal.data.datasets.text_image import TextImageDataset
|
|
11
|
+
from eva.vision.data import datasets as vision_datasets
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PatchCamelyon(TextImageDataset[int], vision_datasets.PatchCamelyon):
    """PatchCamelyon image classification using a multiple choice text prompt.

    Each sample pairs a PatchCamelyon image with a single user message holding
    the prompt. The integer target corresponds to the answer options through
    ``class_to_idx`` (``"A"`` -> 0, ``"B"`` -> 1).
    """

    _default_prompt = (
        "You are a pathology expert helping pathologists to analyze images of tissue samples.\n"
        "Question: Does this image show metastatic breast tissue?\n"
        "Options: A: no, B: yes\n"
        "Only answer with a single letter without further explanation. "
        "Please always provide an answer, even if you are not sure.\n"
        "Answer: "
    )

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"],
        download: bool = False,
        transforms: TransformsSchema | None = None,
        prompt: str | None = None,
        max_samples: int | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: The path to the dataset root. This path should contain
                the uncompressed h5 files and the metadata.
            split: The dataset split for training, validation, or testing.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method.
            transforms: A function/transform which returns a transformed
                version of the raw data samples.
            prompt: The text prompt to use for classification (multiple choice).
                Falls back to the built-in default prompt when empty or None.
            max_samples: Maximum number of samples to use. If None, use all samples.
        """
        super().__init__(root=root, split=split, download=download, transforms=transforms)

        self.max_samples = max_samples
        self.prompt = prompt or self._default_prompt

        if self.max_samples is not None:
            # Overrides the expected split length so that it agrees with the capped
            # `__len__` (presumably checked by the parent's validation — TODO confirm).
            self._expected_length = {split: max_samples}

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"A": 0, "B": 1}

    @override
    def __len__(self) -> int:
        # Explicit `None` check, consistent with `__init__`: `max_samples=0` must
        # not silently fall back to the full dataset length.
        if self.max_samples is not None:
            return self.max_samples
        return self._fetch_dataset_length()

    @override
    def load_text(self, index: int) -> MessageSeries:
        """Returns the (index-independent) prompt as a single user message."""
        return [UserMessage(content=self.prompt)]

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        """Returns the image, delegating explicitly to the vision dataset implementation."""
        return vision_datasets.PatchCamelyon.load_data(self, index)

    @override
    def load_target(self, index: int) -> int:
        """Returns the class label of the sample as a Python int."""
        return int(vision_datasets.PatchCamelyon.load_target(self, index).item())

    @override
    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        """Returns the sample metadata from the vision dataset implementation."""
        return vision_datasets.PatchCamelyon.load_metadata(self, index)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Schema definitions for dataset classes."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from eva.language.data.datasets import schemas as language_schemas
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclasses.dataclass(frozen=True)
class TransformsSchema(language_schemas.TransformsSchema):
    """Transforms applied to multimodal samples.

    Inherits the language-level transform fields and adds an optional image
    transform.
    """

    image: Callable | None = dataclasses.field(default=None)
    """Optional callable applied to the sample image."""
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Base classes for text-image datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Generic
|
|
5
|
+
|
|
6
|
+
from torchvision import tv_tensors
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.language.data.datasets.text import TextDataset
|
|
10
|
+
from eva.multimodal.data.datasets.base import MultimodalDataset
|
|
11
|
+
from eva.multimodal.data.datasets.schemas import TransformsSchema
|
|
12
|
+
from eva.multimodal.data.datasets.typings import TargetType, TextImageSample
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TextImageDataset(
    MultimodalDataset[TextImageSample[TargetType]], TextDataset, abc.ABC, Generic[TargetType]
):
    """Abstract base class for datasets pairing text messages with an image.

    Concrete subclasses implement ``load_image`` in addition to the text, target
    and metadata loaders used by ``__getitem__``.
    """

    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
        """Initializes the dataset.

        Args:
            *args: Positional arguments forwarded to the base class.
            transforms: Optional transforms for the text, image and target of each
                loaded sample.
            **kwargs: Keyword arguments forwarded to the base class.
        """
        super().__init__(*args, **kwargs)

        self.transforms = transforms

    @abc.abstractmethod
    def load_image(self, index: int) -> tv_tensors.Image:
        """Loads the image of a sample.

        Args:
            index: The index of the data sample.

        Returns:
            The image content.
        """
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> TextImageSample[TargetType]:
        text = self.load_text(index)
        image = self.load_image(index)
        target = self.load_target(index)
        # Missing metadata is normalised to an empty dict.
        metadata = self.load_metadata(index) or {}
        sample = TextImageSample(text=text, image=image, target=target, metadata=metadata)
        return self._apply_transforms(sample)

    @override
    def _apply_transforms(self, sample: TextImageSample[TargetType]) -> TextImageSample[TargetType]:
        """Applies the configured transforms to the text, image and target of a sample.

        Args:
            sample: The sample containing text, image, target and metadata.

        Returns:
            The input sample when no transforms are configured; otherwise a new
            sample with each configured transform applied and metadata untouched.
        """
        transforms = self.transforms
        if not transforms:
            return sample

        text_fn = transforms.text
        image_fn = transforms.image
        target_fn = transforms.target
        return TextImageSample(
            text=text_fn(sample.text) if text_fn else sample.text,
            image=image_fn(sample.image) if image_fn else sample.image,
            target=target_fn(sample.target) if target_fn else sample.target,
            metadata=sample.metadata,
        )
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Typings for multimodal datasets."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from torchvision import tv_tensors
|
|
6
|
+
from typing_extensions import NamedTuple
|
|
7
|
+
|
|
8
|
+
from eva.language.data.messages import MessageSeries
|
|
9
|
+
|
|
10
|
+
TargetType = TypeVar("TargetType")
"""Type variable for the target of a sample (e.g. ``int`` class labels)."""


class TextImageSample(NamedTuple, Generic[TargetType]):
    """A single multimodal sample: conversation text, image, target and metadata."""

    text: MessageSeries
    """The conversation, as one or more messages."""

    image: tv_tensors.Image
    """The image tensor."""

    target: TargetType | None
    """The target, or ``None`` if the sample is unlabeled."""

    metadata: dict[str, Any] | None
    """Optional additional per-sample information."""
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Model module for vision-language models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
6
|
+
from torch import nn
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.metrics import structs as metrics_lib
|
|
10
|
+
from eva.core.models.modules import module
|
|
11
|
+
from eva.core.models.modules.utils import batch_postprocess
|
|
12
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VisionLanguageModule(module.ModelModule):
    """Model module for vision-language tasks.

    Runs a wrapped model on ``TextImageBatch`` inputs during validation and test
    steps, and returns the inputs, predictions, targets and metadata for
    downstream metric computation and writers.
    """

    def __init__(
        self,
        model: nn.Module,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the vision-language module.

        Args:
            model: Model instance to use for the forward pass; it is called with
                the full ``TextImageBatch``.
            metrics: Metrics schema for evaluation.
            postprocess: A helper function to post-process model outputs before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

        self.model = model

    @override
    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> List[str]:
        """Runs the wrapped model on a batch and returns its generated texts."""
        return self.model(batch)

    @override
    def validation_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
        """Runs inference on a batch and bundles it with the inputs and labels.

        The image entry is consumed by the model only; it is not part of the output.
        """
        text, _, targets, metadata = TextImageBatch(*batch)
        predictions = self.forward(batch)
        return {
            "inputs": text,
            "predictions": predictions,
            "targets": targets,
            "metadata": metadata,
        }
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Multimodal networks API."""
|
|
2
|
+
|
|
3
|
+
from eva.multimodal.models.networks.alibaba import Qwen25VL7BInstruct
|
|
4
|
+
from eva.multimodal.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
|
|
5
|
+
from eva.multimodal.models.networks.others import PathoR13b
|
|
6
|
+
from eva.multimodal.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
# Public API of the multimodal networks package: the model classes imported
# above plus the shared model registry.
__all__ = [
    "Claude35Sonnet20240620",
    "Claude37Sonnet20250219",
    "PathoR13b",
    "Qwen25VL7BInstruct",
    "model_registry",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Models from Alibaba."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from eva.multimodal.models import wrappers
|
|
6
|
+
from eva.multimodal.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@model_registry.register("alibaba/qwen2-5-vl-7b-instruct")
class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
    """Qwen2.5-VL 7B Instruct model, loaded through the ``HuggingFaceModel`` wrapper."""

    def __init__(
        self,
        system_prompt: str | None = None,
        cache_dir: str | None = None,
        attn_implementation: str = "flash_attention_2",
    ):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the wrapper.
            cache_dir: Optional cache directory passed in the model kwargs.
            attn_implementation: Attention implementation passed in the model kwargs.
        """
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True,
            "cache_dir": cache_dir,
            "attn_implementation": attn_implementation,
        }
        # Greedy decoding.
        generation_kwargs = {
            "max_new_tokens": 512,
            "do_sample": False,
        }
        processor_kwargs = {
            "padding": True,
            "padding_side": "left",
            "max_pixels": 451584,  # 672 * 672
        }
        super().__init__(
            model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
            model_class="Qwen2_5_VLForConditionalGeneration",
            model_kwargs=model_kwargs,
            generation_kwargs=generation_kwargs,
            processor_kwargs=processor_kwargs,
            system_prompt=system_prompt,
        )
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Models from Anthropic."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from eva.multimodal.models import wrappers
|
|
6
|
+
from eva.multimodal.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _Claude(wrappers.LiteLLMModel):
    """Shared base class for Anthropic Claude models accessed via the LiteLLM wrapper."""

    def __init__(self, model_name: str, system_prompt: str | None = None):
        """Check for the Anthropic API key, then initialize the wrapper.

        Args:
            model_name: The Anthropic model identifier.
            system_prompt: Optional system prompt forwarded to the wrapper.

        Raises:
            ValueError: If ``ANTHROPIC_API_KEY`` is unset or empty.
        """
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")

        super().__init__(model_name=model_name, system_prompt=system_prompt)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
class Claude35Sonnet20240620(_Claude):
    """Claude 3.5 Sonnet (June 2024) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the base class.
        """
        model_name = "claude-3-5-sonnet-20240620"
        super().__init__(model_name=model_name, system_prompt=system_prompt)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
class Claude37Sonnet20250219(_Claude):
    """Claude 3.7 Sonnet (February 2025) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the base class.
        """
        model_name = "claude-3-7-sonnet-20250219"
        super().__init__(model_name=model_name, system_prompt=system_prompt)
|