kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/dataloaders/dataloader.py +3 -1
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/log/__init__.py +2 -1
- eva/core/loggers/log/table.py +73 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,11 +16,11 @@ class CastStrToIntTensor:
|
|
|
16
16
|
Supports single values, lists of strings, or lists of integers.
|
|
17
17
|
|
|
18
18
|
Example:
|
|
19
|
-
>>> # Default mapping for
|
|
20
|
-
>>> transform = CastStrToIntTensor()
|
|
21
|
-
>>> transform(['
|
|
19
|
+
>>> # Default mapping for A/B/C classification
|
|
20
|
+
>>> transform = CastStrToIntTensor(mapping={"A": 0, "B": 1, "C": 2})
|
|
21
|
+
>>> transform(['B', 'A', 'C'])
|
|
22
22
|
tensor([1, 0, 2])
|
|
23
|
-
>>> transform('
|
|
23
|
+
>>> transform('B')
|
|
24
24
|
tensor([1])
|
|
25
25
|
|
|
26
26
|
>>> # Custom mapping
|
|
@@ -29,20 +29,25 @@ class CastStrToIntTensor:
|
|
|
29
29
|
tensor([1, 0])
|
|
30
30
|
"""
|
|
31
31
|
|
|
def __init__(
    self, mapping: Dict[str, int], standalone_words: bool = True, case_sensitive: bool = True
) -> None:
    r"""Initialize the transform with a regex-to-integer mapping.

    Args:
        mapping: Dictionary mapping regex patterns to integers.
        standalone_words: If True, patterns only match as standalone words: each
            pattern is wrapped as ``\b(?:pattern)\b``. The non-capturing group makes
            the word boundaries apply to the whole pattern, so alternations such as
            ``'no|false'`` do not match inside other words (e.g. ``'nothing'``).
        case_sensitive: If True, regex patterns are case-sensitive.
    """
    # Without the group, `\bno|false\b` would parse as `(\bno)|(false\b)`.
    self.mapping = (
        {rf"\b(?:{pattern})\b": value for pattern, value in mapping.items()}
        if standalone_words
        else mapping
    )

    flags = 0 if case_sensitive else re.IGNORECASE
    self.compiled_patterns = [
        (re.compile(pattern, flags), value) for pattern, value in self.mapping.items()
    ]
|
|
47
52
|
|
|
48
53
|
def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
|
|
@@ -58,7 +63,10 @@ class CastStrToIntTensor:
|
|
|
58
63
|
ValueError: If any value cannot be mapped to an integer.
|
|
59
64
|
"""
|
|
60
65
|
return torch.tensor(
|
|
61
|
-
[
|
|
66
|
+
[
|
|
67
|
+
self._cast_single(v)
|
|
68
|
+
for v in (values if isinstance(values, list | tuple) else [values])
|
|
69
|
+
],
|
|
62
70
|
dtype=torch.int,
|
|
63
71
|
)
|
|
64
72
|
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Message formatting utilities for language models."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
from eva.language.data.messages import (
|
|
8
|
+
AssistantMessage,
|
|
9
|
+
MessageSeries,
|
|
10
|
+
Role,
|
|
11
|
+
SystemMessage,
|
|
12
|
+
UserMessage,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def format_chat_message(message: MessageSeries) -> List[Dict[str, Any]]:
    """Convert a message series into a list of OpenAI-style chat dictionaries.

    Args:
        message: The message series to convert.

    Returns:
        One ``{"role": ..., "content": ...}`` dict per message, in the original order.
    """
    formatted: List[Dict[str, Any]] = []
    for entry in message:
        formatted.append({"role": entry.role, "content": entry.content})
    return formatted
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def combine_system_messages(message: MessageSeries, join_char: str = "\n") -> MessageSeries:
    """Merge all system messages of a series into a single leading system message.

    A message series may contain several system messages, such as
    `ModelSystemMessage` and `TaskSystemMessage`. Most models / APIs expect a
    single system message, so this function combines them into one.

    Args:
        message: The message series containing one or multiple messages.
        join_char: The string used to join the system message contents.
            Default is newline.

    Returns:
        The input series as-is if it contains no system messages. Otherwise, a new
        series that starts with the merged system message, followed by the remaining
        messages in their original order.
    """
    system_messages = []
    other_messages = []
    for item in message:
        if item.role == Role.SYSTEM:
            system_messages.append(item)
        else:
            other_messages.append(item)

    if not system_messages:
        return message

    merged = SystemMessage(content=merge_message_contents(system_messages, join_char=join_char))
    return [merged, *other_messages]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def merge_message_contents(message: MessageSeries, join_char: str = "\n") -> str:
    """Concatenate the contents of all messages in a series into one string.

    Args:
        message: The message series whose contents are combined.
        join_char: The string placed between consecutive message contents.
            Default is newline.

    Returns:
        The message contents joined by ``join_char``, in series order.
    """
    contents = (entry.content for entry in message)
    return join_char.join(contents)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def insert_system_message(
    message: MessageSeries, system_message: SystemMessage | None
) -> MessageSeries:
    """Prepend a system message to a message series.

    Args:
        message: The message series to extend.
        system_message: The system message to put first, or None to skip insertion.

    Returns:
        ``message`` itself when ``system_message`` is None; otherwise a new series
        with ``system_message`` as its first element.
    """
    return message if system_message is None else [system_message] + message
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def batch_insert_system_message(
    messages: List[MessageSeries], system_message: SystemMessage | None
) -> List[MessageSeries]:
    """Prepend a system message to every message series in a batch.

    Args:
        messages: The batch of message series.
        system_message: The system message to put first in each series, or None to
            leave every series as-is.

    Returns:
        A new list with one series per input series, in the same order.
    """
    prepend = functools.partial(insert_system_message, system_message=system_message)
    return [prepend(series) for series in messages]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def serialize(messages: MessageSeries) -> str:
    """Serialize a MessageSeries into a JSON string.

    Args:
        messages: A list of message objects (MessageSeries).

    Returns:
        A JSON array with one object per message, in the following format:
        [{"role": "user", "content": "Hello"}, ...]
    """
    return json.dumps(format_chat_message(messages))
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def deserialize(messages: str) -> MessageSeries:
|
|
92
|
+
"""Convert a json string to a MessageSeries object.
|
|
93
|
+
|
|
94
|
+
Format: [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
|
|
95
|
+
"""
|
|
96
|
+
message_dicts = json.loads(messages)
|
|
97
|
+
|
|
98
|
+
message_series = []
|
|
99
|
+
for message_dict in message_dicts:
|
|
100
|
+
if "role" not in message_dict or "content" not in message_dict:
|
|
101
|
+
raise ValueError("`role` or `content` keys are missing.")
|
|
102
|
+
|
|
103
|
+
match message_dict["role"]:
|
|
104
|
+
case Role.USER:
|
|
105
|
+
message_series.append(UserMessage(**message_dict))
|
|
106
|
+
case Role.ASSISTANT:
|
|
107
|
+
message_series.append(AssistantMessage(**message_dict))
|
|
108
|
+
case Role.SYSTEM:
|
|
109
|
+
message_series.append(SystemMessage(**message_dict))
|
|
110
|
+
case _:
|
|
111
|
+
raise ValueError(f"Unknown role: {message_dict['role']}")
|
|
112
|
+
|
|
113
|
+
return message_series
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Text prediction writer callbacks."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.callbacks import writers
|
|
9
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TextPredictionWriter(writers.TextPredictionWriter):
    """Callback that writes generated text predictions for text-image batches to disk.

    It reuses the language ``TextPredictionWriter``. Only the batch unpacking is
    adapted to ``TextImageBatch``; the image entries are dropped before writing.
    """

    def __init__(
        self,
        output_dir: str,
        model: nn.Module,
        dataloader_idx_map: Dict[int, str] | None = None,
        metadata_keys: List[str] | None = None,
        include_input: bool = True,
        overwrite: bool = False,
        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
    ) -> None:
        """See docstring of base class."""
        super().__init__(
            output_dir=output_dir,
            model=model,
            dataloader_idx_map=dataloader_idx_map,
            metadata_keys=metadata_keys,
            include_input=include_input,
            overwrite=overwrite,
            save_format=save_format,
        )

    @override
    def _unpack_batch(self, batch: TextImageBatch) -> Tuple[list, list | None, dict | None]:  # type: ignore
        # Returns (text, target, metadata); the image field is not needed for writing.
        unpacked = TextImageBatch(*batch)
        return unpacked.text, unpacked.target, unpacked.metadata
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Collate functions for text-image data."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from torch.utils.data._utils.collate import default_collate
|
|
6
|
+
|
|
7
|
+
from eva.multimodal.data.datasets.typings import TextImageSample
|
|
8
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def text_image_collate(batch: List[TextImageSample]) -> TextImageBatch:
    """Collate function for text-image batches.

    Texts and images are kept as plain Python lists (not stacked). Targets are
    collated with ``default_collate`` when the first sample has a target. Metadata
    is transposed from per-sample dicts into a dict of per-key lists, using the keys
    of the first sample.

    Args:
        batch: The list of samples to collate.

    Returns:
        The collated batch.
    """
    texts, images, targets, metadata_list = zip(*batch, strict=False)

    metadata = None
    first_metadata = metadata_list[0]
    if first_metadata is not None:
        # Fill missing metadata (None / empty dict / absent key) with None instead of
        # skipping the sample, so every per-key list stays aligned with the batch.
        metadata = {
            key: [(sample_metadata or {}).get(key) for sample_metadata in metadata_list]
            for key in first_metadata
        }

    return TextImageBatch(
        text=list(texts),
        image=list(images),
        target=default_collate(targets) if targets[0] is not None else None,
        metadata=metadata,
    )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Multimodal Dataset base class."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from eva.core.data.datasets import base
|
|
7
|
+
|
|
8
|
+
DataSample = TypeVar("DataSample")
"""Type of the data samples returned by a dataset."""


class MultimodalDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
    """Abstract base dataset for multimodal tasks, built on ``base.MapDataset``.

    Generic over the sample type it returns (``DataSample``).
    """
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""PatchCamelyon dataset with text prompts for multimodal classification."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Literal
|
|
4
|
+
|
|
5
|
+
from torchvision import tv_tensors
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.data.messages import MessageSeries, UserMessage
|
|
9
|
+
from eva.multimodal.data.datasets.schemas import TransformsSchema
|
|
10
|
+
from eva.multimodal.data.datasets.text_image import TextImageDataset
|
|
11
|
+
from eva.vision.data import datasets as vision_datasets
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PatchCamelyon(TextImageDataset[int], vision_datasets.PatchCamelyon):
    """PatchCamelyon image classification using a multiple choice text prompt.

    Every sample pairs a PatchCamelyon image with the same multiple-choice prompt.
    Targets are the integer labels of the underlying vision dataset. Answer letters
    map to labels via ``class_to_idx`` ("A" -> 0, "B" -> 1).
    """

    _default_prompt = (
        "You are a pathology expert helping pathologists to analyze images of tissue samples.\n"
        "Question: Does this image show metastatic breast tissue?\n"
        "Options: A: no, B: yes\n"
        "Only answer with a single letter without further explanation. "
        "Please always provide an answer, even if you are not sure.\n"
        "Answer: "
    )

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"],
        download: bool = False,
        transforms: TransformsSchema | None = None,
        prompt: str | None = None,
        max_samples: int | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: The path to the dataset root. This path should contain
                the uncompressed h5 files and the metadata.
            split: The dataset split for training, validation, or testing.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method.
            transforms: A function/transform which returns a transformed
                version of the raw data samples.
            prompt: The text prompt to use for classification (multiple choice).
                Falls back to the default prompt when None or empty.
            max_samples: Maximum number of samples to use. If None, use all samples.
        """
        super().__init__(root=root, split=split, download=download, transforms=transforms)

        self.max_samples = max_samples
        self.prompt = prompt or self._default_prompt

        if self.max_samples is not None:
            # Override the expected split length (presumably checked by the vision
            # base class' validation — confirm) so it agrees with `__len__`.
            self._expected_length = {split: max_samples}

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"A": 0, "B": 1}

    @override
    def __len__(self) -> int:
        # Use `is not None` (not truthiness) so `max_samples=0` is honoured, matching
        # the `_expected_length` override set in `__init__`.
        if self.max_samples is not None:
            return self.max_samples
        return self._fetch_dataset_length()

    @override
    def load_text(self, index: int) -> MessageSeries:
        return [UserMessage(content=self.prompt)]

    # The loaders below call the vision implementation explicitly, because the
    # multimodal/text base classes precede it in the MRO.
    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        return vision_datasets.PatchCamelyon.load_data(self, index)

    @override
    def load_target(self, index: int) -> int:
        return int(vision_datasets.PatchCamelyon.load_target(self, index).item())

    @override
    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        return vision_datasets.PatchCamelyon.load_metadata(self, index)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Schema definitions for dataset classes."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from eva.language.data.datasets import schemas as language_schemas
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclasses.dataclass(frozen=True)
class TransformsSchema(language_schemas.TransformsSchema):
    """Dataset transforms schema that adds an image transform to the language schema."""

    image: Callable | None = None
    """Transform applied to the image of each sample, or None for no transform."""
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Base classes for text-image datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Generic
|
|
5
|
+
|
|
6
|
+
from torchvision import tv_tensors
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.language.data.datasets.text import TextDataset
|
|
10
|
+
from eva.multimodal.data.datasets.base import MultimodalDataset
|
|
11
|
+
from eva.multimodal.data.datasets.schemas import TransformsSchema
|
|
12
|
+
from eva.multimodal.data.datasets.typings import TargetType, TextImageSample
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TextImageDataset(
    MultimodalDataset[TextImageSample[TargetType]], TextDataset, abc.ABC, Generic[TargetType]
):
    """Abstract base dataset for tasks that pair text messages with an image.

    Subclasses implement the text, image, target and metadata loaders. This class
    assembles them into a ``TextImageSample`` and applies the configured transforms.
    """

    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
        """Initializes the dataset.

        Args:
            *args: Positional arguments forwarded to the base class.
            transforms: Optional transforms applied to the text, image and target
                of every loaded sample.
            **kwargs: Keyword arguments forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        self.transforms = transforms

    @abc.abstractmethod
    def load_image(self, index: int) -> tv_tensors.Image:
        """Loads the image of a data sample.

        Args:
            index: The index of the data sample.

        Returns:
            The image content.
        """
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> TextImageSample[TargetType]:
        text = self.load_text(index)
        image = self.load_image(index)
        target = self.load_target(index)
        # Missing metadata is normalised to an empty dict.
        metadata = self.load_metadata(index) or {}
        sample = TextImageSample(text=text, image=image, target=target, metadata=metadata)
        return self._apply_transforms(sample)

    @override
    def _apply_transforms(self, sample: TextImageSample[TargetType]) -> TextImageSample[TargetType]:
        """Runs the configured text, image and target transforms on a sample.

        Args:
            sample: The sample containing text, image, target and metadata.

        Returns:
            A new sample with the transformed fields (metadata is passed through),
            or the input sample itself when no transforms are configured.
        """
        transforms = self.transforms
        if not transforms:
            return sample

        text, image, target = sample.text, sample.image, sample.target
        if transforms.text:
            text = transforms.text(text)
        if transforms.image:
            image = transforms.image(image)
        if transforms.target:
            target = transforms.target(target)

        return TextImageSample(text=text, image=image, target=target, metadata=sample.metadata)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Typings for multimodal datasets."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from torchvision import tv_tensors
|
|
6
|
+
from typing_extensions import NamedTuple
|
|
7
|
+
|
|
8
|
+
from eva.language.data.messages import MessageSeries
|
|
9
|
+
|
|
10
|
+
TargetType = TypeVar("TargetType")
"""Type of the target attached to a sample."""


class TextImageSample(NamedTuple, Generic[TargetType]):
    """A single multimodal sample: conversation text, image, target and metadata."""

    text: MessageSeries
    """The conversation messages (one or more) of the sample."""

    image: tv_tensors.Image
    """The image of the sample."""

    target: TargetType | None
    """The target of the sample, or None if unavailable."""

    metadata: dict[str, Any] | None
    """Optional additional metadata of the sample."""
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Model module for vision-language models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
6
|
+
from torch import nn
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.metrics import structs as metrics_lib
|
|
10
|
+
from eva.core.models.modules import module
|
|
11
|
+
from eva.core.models.modules.utils import batch_postprocess
|
|
12
|
+
from eva.multimodal.models.typings import TextImageBatch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VisionLanguageModule(module.ModelModule):
    """Lightning model module that runs a vision-language model on text-image batches."""

    def __init__(
        self,
        model: nn.Module,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the vision-language module.

        Args:
            model: Model instance to use for forward pass.
            metrics: Metrics schema for evaluation.
            postprocess: A helper function to post-process model outputs before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)
        self.model = model

    @override
    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> List[str]:
        return self.model(batch)

    @override
    def validation_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
        """Runs the model on a batch and packs inputs, predictions, targets and metadata."""
        unpacked = TextImageBatch(*batch)
        predictions = self.forward(batch)
        return {
            "inputs": unpacked.text,
            "predictions": predictions,
            "targets": unpacked.target,
            "metadata": unpacked.metadata,
        }
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Multimodal networks API."""
|
|
2
|
+
|
|
3
|
+
from eva.multimodal.models.networks.alibaba import Qwen25VL7BInstruct
|
|
4
|
+
from eva.multimodal.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
|
|
5
|
+
from eva.multimodal.models.networks.others import PathoR13b
|
|
6
|
+
from eva.multimodal.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
# Public API of this package (network classes and the shared model registry).
__all__ = [
    "Claude35Sonnet20240620",
    "Claude37Sonnet20250219",
    "PathoR13b",
    "Qwen25VL7BInstruct",
    "model_registry",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Models from Alibaba."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from eva.multimodal.models import wrappers
|
|
6
|
+
from eva.multimodal.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@model_registry.register("alibaba/qwen2-5-vl-7b-instruct")
class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
    """Qwen2.5-VL 7B Instruct model, loaded through the Hugging Face wrapper."""

    def __init__(
        self,
        system_prompt: str | None = None,
        cache_dir: str | None = None,
        attn_implementation: str = "flash_attention_2",
    ):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt forwarded to the wrapper.
            cache_dir: Optional cache directory, forwarded via ``model_kwargs``.
            attn_implementation: Attention implementation, forwarded via ``model_kwargs``.
        """
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True,
            "cache_dir": cache_dir,
            "attn_implementation": attn_implementation,
        }
        # No sampling (deterministic decoding), capped at 512 new tokens.
        generation_kwargs = {"max_new_tokens": 512, "do_sample": False}
        processor_kwargs = {
            "padding": True,
            "padding_side": "left",
            "max_pixels": 672 * 672,  # = 451584
        }
        super().__init__(
            model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
            model_class="Qwen2_5_VLForConditionalGeneration",
            model_kwargs=model_kwargs,
            generation_kwargs=generation_kwargs,
            processor_kwargs=processor_kwargs,
            system_prompt=system_prompt,
        )
|