kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic.
Files changed (101)
  1. eva/core/callbacks/config.py +4 -0
  2. eva/core/cli/setup.py +1 -1
  3. eva/core/data/dataloaders/__init__.py +1 -2
  4. eva/core/data/dataloaders/dataloader.py +3 -1
  5. eva/core/data/samplers/random.py +17 -10
  6. eva/core/interface/interface.py +21 -0
  7. eva/core/loggers/log/__init__.py +2 -1
  8. eva/core/loggers/log/table.py +73 -0
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +2 -0
  16. eva/language/__init__.py +2 -1
  17. eva/language/callbacks/__init__.py +5 -0
  18. eva/language/callbacks/writers/__init__.py +5 -0
  19. eva/language/callbacks/writers/prediction.py +176 -0
  20. eva/language/data/dataloaders/__init__.py +5 -0
  21. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  22. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  23. eva/language/data/datasets/__init__.py +3 -1
  24. eva/language/data/datasets/{language.py → base.py} +1 -1
  25. eva/language/data/datasets/classification/base.py +3 -43
  26. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  27. eva/language/data/datasets/prediction.py +151 -0
  28. eva/language/data/datasets/schemas.py +18 -0
  29. eva/language/data/datasets/text.py +92 -0
  30. eva/language/data/datasets/typings.py +39 -0
  31. eva/language/data/messages.py +60 -0
  32. eva/language/models/__init__.py +15 -11
  33. eva/language/models/modules/__init__.py +2 -2
  34. eva/language/models/modules/language.py +93 -0
  35. eva/language/models/networks/__init__.py +12 -0
  36. eva/language/models/networks/alibaba.py +26 -0
  37. eva/language/models/networks/api/__init__.py +11 -0
  38. eva/language/models/networks/api/anthropic.py +34 -0
  39. eva/language/models/networks/registry.py +5 -0
  40. eva/language/models/typings.py +39 -0
  41. eva/language/models/wrappers/__init__.py +13 -5
  42. eva/language/models/wrappers/base.py +47 -0
  43. eva/language/models/wrappers/from_registry.py +54 -0
  44. eva/language/models/wrappers/huggingface.py +44 -8
  45. eva/language/models/wrappers/litellm.py +81 -46
  46. eva/language/models/wrappers/vllm.py +37 -13
  47. eva/language/utils/__init__.py +2 -1
  48. eva/language/utils/str_to_int_tensor.py +20 -12
  49. eva/language/utils/text/__init__.py +5 -0
  50. eva/language/utils/text/messages.py +113 -0
  51. eva/multimodal/__init__.py +6 -0
  52. eva/multimodal/callbacks/__init__.py +5 -0
  53. eva/multimodal/callbacks/writers/__init__.py +5 -0
  54. eva/multimodal/callbacks/writers/prediction.py +39 -0
  55. eva/multimodal/data/__init__.py +5 -0
  56. eva/multimodal/data/dataloaders/__init__.py +5 -0
  57. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  58. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  59. eva/multimodal/data/datasets/__init__.py +6 -0
  60. eva/multimodal/data/datasets/base.py +13 -0
  61. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  62. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  63. eva/multimodal/data/datasets/schemas.py +14 -0
  64. eva/multimodal/data/datasets/text_image.py +77 -0
  65. eva/multimodal/data/datasets/typings.py +27 -0
  66. eva/multimodal/models/__init__.py +8 -0
  67. eva/multimodal/models/modules/__init__.py +5 -0
  68. eva/multimodal/models/modules/vision_language.py +55 -0
  69. eva/multimodal/models/networks/__init__.py +14 -0
  70. eva/multimodal/models/networks/alibaba.py +39 -0
  71. eva/multimodal/models/networks/api/__init__.py +11 -0
  72. eva/multimodal/models/networks/api/anthropic.py +34 -0
  73. eva/multimodal/models/networks/others.py +47 -0
  74. eva/multimodal/models/networks/registry.py +5 -0
  75. eva/multimodal/models/typings.py +27 -0
  76. eva/multimodal/models/wrappers/__init__.py +13 -0
  77. eva/multimodal/models/wrappers/base.py +47 -0
  78. eva/multimodal/models/wrappers/from_registry.py +54 -0
  79. eva/multimodal/models/wrappers/huggingface.py +180 -0
  80. eva/multimodal/models/wrappers/litellm.py +56 -0
  81. eva/multimodal/utils/__init__.py +1 -0
  82. eva/multimodal/utils/image/__init__.py +5 -0
  83. eva/multimodal/utils/image/encode.py +28 -0
  84. eva/multimodal/utils/text/__init__.py +1 -0
  85. eva/multimodal/utils/text/messages.py +79 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  87. eva/vision/data/transforms/__init__.py +2 -1
  88. eva/vision/data/transforms/spatial/__init__.py +2 -1
  89. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  90. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  91. eva/vision/data/transforms/spatial/resize.py +62 -0
  92. eva/vision/models/wrappers/from_registry.py +6 -5
  93. eva/vision/models/wrappers/from_timm.py +6 -4
  94. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
  95. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
  96. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  97. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  98. eva/language/models/modules/text.py +0 -85
  99. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
  100. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
  101. {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,11 +16,11 @@ class CastStrToIntTensor:
     Supports single values, lists of strings, or lists of integers.
 
     Example:
-        >>> # Default mapping for yes/no/maybe classification
-        >>> transform = CastStrToIntTensor()
-        >>> transform(['yes', 'no', 'maybe'])
+        >>> # Default mapping for A/B/C classification
+        >>> transform = CastStrToIntTensor(mapping={"A": 0, "B": 1, "C": 2})
+        >>> transform(['B', 'A', 'C'])
         tensor([1, 0, 2])
-        >>> transform('yes')
+        >>> transform('B')
         tensor([1])
 
         >>> # Custom mapping
@@ -29,20 +29,25 @@ class CastStrToIntTensor:
         tensor([1, 0])
     """
 
-    def __init__(self, mapping: Dict[str, int] | None = None):
-        """Initialize the transform with a regex-to-integer mapping.
+    def __init__(
+        self, mapping: Dict[str, int], standalone_words: bool = True, case_sensitive: bool = True
+    ) -> None:
+        r"""Initialize the transform with a regex-to-integer mapping.
 
         Args:
             mapping: Dictionary mapping regex patterns to integers. If None, uses default
                 yes/no/maybe mapping: {'no': 0, 'yes': 1, 'maybe': 2}
+            standalone_words: If True, patterns are treated as standalone words (e.g., '\bno\b').
+            case_sensitive: If True, regex patterns are case-sensitive.
         """
-        if mapping is None:
-            self.mapping = {r"\bno\b": 0, r"\byes\b": 1, r"\bmaybe\b": 2}
-        else:
-            self.mapping = mapping
+        self.mapping = mapping
+
+        if standalone_words:
+            self.mapping = {rf"\b{k}\b": v for k, v in mapping.items()}
 
         self.compiled_patterns = [
-            (re.compile(pattern, re.IGNORECASE), value) for pattern, value in self.mapping.items()
+            (re.compile(pattern, 0 if case_sensitive else re.IGNORECASE), value)
+            for pattern, value in self.mapping.items()
         ]
 
     def __call__(self, values: Union[str, List[str], List[int]]) -> torch.Tensor:
@@ -58,7 +63,10 @@ class CastStrToIntTensor:
             ValueError: If any value cannot be mapped to an integer.
         """
         return torch.tensor(
-            [self._cast_single(v) for v in (values if isinstance(values, list) else [values])],
+            [
+                self._cast_single(v)
+                for v in (values if isinstance(values, list | tuple) else [values])
+            ],
            dtype=torch.int,
        )
 
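For reference, a short usage sketch of the changed CastStrToIntTensor API (import path taken from the file list above; expected values follow the class docstring):

from eva.language.utils.str_to_int_tensor import CastStrToIntTensor

# `mapping` is now a required argument; with standalone_words=True (the
# default) each key is wrapped as a standalone-word pattern, e.g. r"\bA\b".
transform = CastStrToIntTensor(mapping={"A": 0, "B": 1, "C": 2})
print(transform(["B", "A", "C"]))  # tensor with values [1, 0, 2]
print(transform("B"))              # tensor with values [1]

# Matching is now case-sensitive by default; opt out explicitly:
lenient = CastStrToIntTensor(mapping={"yes": 1, "no": 0}, case_sensitive=False)
print(lenient(["Yes", "NO"]))      # tensor with values [1, 0]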
@@ -0,0 +1,5 @@
+"""Text utilities for language models."""
+
+from eva.language.utils.text.messages import format_chat_message
+
+__all__ = ["format_chat_message"]
@@ -0,0 +1,113 @@
+"""Message formatting utilities for language models."""
+
+import functools
+import json
+from typing import Any, Dict, List
+
+from eva.language.data.messages import (
+    AssistantMessage,
+    MessageSeries,
+    Role,
+    SystemMessage,
+    UserMessage,
+)
+
+
+def format_chat_message(message: MessageSeries) -> List[Dict[str, Any]]:
+    """Formats a message series into a format following OpenAI's API specification."""
+    return [{"role": item.role, "content": item.content} for item in message]
+
+
+def combine_system_messages(message: MessageSeries, join_char: str = "\n") -> MessageSeries:
+    """Combine system messages into a single message.
+
+    This is useful when the MessageSeries contains multiple system messages such
+    as `ModelSystemMessage` and `TaskSystemMessage`. Given that most models / APIs
+    expect a single system message, this function can be used to combine them into one.
+
+    Args:
+        message: The message series containing one or multiple messages.
+        join_char: The character to use to join the system messages. Default is newline.
+
+    Returns:
+        A new message series with the system messages combined into one and the
+        remaining messages unchanged.
+    """
+    system_messages = list(filter(lambda item: item.role == Role.SYSTEM, message))
+    if len(system_messages) == 0:
+        return message
+
+    non_system_messages = list(filter(lambda item: item.role != Role.SYSTEM, message))
+    return [
+        SystemMessage(content=merge_message_contents(system_messages, join_char=join_char))
+    ] + non_system_messages
+
+
+def merge_message_contents(message: MessageSeries, join_char: str = "\n") -> str:
+    """Merges all the contents within a message series into a string.
+
+    Args:
+        message: The message series to combine.
+        join_char: The character to use to join the message contents. Default is newline.
+
+    Returns:
+        A string containing the combined message contents.
+    """
+    return join_char.join(item.content for item in message)
+
+
+def insert_system_message(
+    message: MessageSeries, system_message: SystemMessage | None
+) -> MessageSeries:
+    """Insert a system message at the beginning of the message series."""
+    if system_message is None:
+        return message
+    return [system_message] + message
+
+
+def batch_insert_system_message(
+    messages: List[MessageSeries], system_message: SystemMessage | None
+) -> List[MessageSeries]:
+    """Insert a system message at the beginning of each message series in a batch."""
+    return list(
+        map(functools.partial(insert_system_message, system_message=system_message), messages)
+    )
+
+
+def serialize(messages: MessageSeries) -> str:
+    """Serialize a MessageSeries object into a JSON string.
+
+    Args:
+        messages: A list of message objects (MessageSeries).
+
+    Returns:
+        A JSON string representing the message series, with the following format:
+        [{"role": "user", "content": "Hello"}, ...]
+    """
+    serialized_messages = format_chat_message(messages)
+    return json.dumps(serialized_messages)
+
+
+def deserialize(messages: str) -> MessageSeries:
+    """Convert a JSON string to a MessageSeries object.
+
+    Format: [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
+    """
+    message_dicts = json.loads(messages)
+
+    message_series = []
+    for message_dict in message_dicts:
+        if "role" not in message_dict or "content" not in message_dict:
+            raise ValueError("`role` or `content` keys are missing.")
+
+        match message_dict["role"]:
+            case Role.USER:
+                message_series.append(UserMessage(**message_dict))
+            case Role.ASSISTANT:
+                message_series.append(AssistantMessage(**message_dict))
+            case Role.SYSTEM:
+                message_series.append(SystemMessage(**message_dict))
+            case _:
+                raise ValueError(f"Unknown role: {message_dict['role']}")
+
+    return message_series
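A round-trip sketch of these utilities (it assumes the message classes accept a `content` keyword, as in the hunks above and in the PatchCamelyon dataset below):

from eva.language.data.messages import SystemMessage, UserMessage
from eva.language.utils.text.messages import (
    combine_system_messages,
    deserialize,
    serialize,
)

series = [
    SystemMessage(content="You are a pathology expert."),
    SystemMessage(content="Answer with a single letter."),
    UserMessage(content="Does this image show metastatic tissue?"),
]

# Collapse both system messages into one, joined with "\n" by default.
combined = combine_system_messages(series)

# serialize/deserialize round-trip through the OpenAI-style JSON format.
payload = serialize(combined)
assert serialize(deserialize(payload)) == payload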
@@ -0,0 +1,6 @@
+"""Multimodal API."""
+
+from eva.multimodal import models
+from eva.multimodal.data import datasets
+
+__all__ = ["models", "datasets"]
@@ -0,0 +1,5 @@
+"""Multimodal callbacks API."""
+
+from eva.multimodal.callbacks.writers import TextPredictionWriter
+
+__all__ = ["TextPredictionWriter"]
@@ -0,0 +1,5 @@
+"""Multimodal writers callbacks API."""
+
+from eva.multimodal.callbacks.writers.prediction import TextPredictionWriter
+
+__all__ = ["TextPredictionWriter"]
@@ -0,0 +1,39 @@
+"""Text prediction writer callbacks."""
+
+from typing import Dict, List, Literal, Tuple
+
+from torch import nn
+from typing_extensions import override
+
+from eva.language.callbacks import writers
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class TextPredictionWriter(writers.TextPredictionWriter):
+    """Callback for writing generated text predictions to disk."""
+
+    def __init__(
+        self,
+        output_dir: str,
+        model: nn.Module,
+        dataloader_idx_map: Dict[int, str] | None = None,
+        metadata_keys: List[str] | None = None,
+        include_input: bool = True,
+        overwrite: bool = False,
+        save_format: Literal["jsonl", "parquet", "csv"] = "jsonl",
+    ) -> None:
+        """See docstring of base class."""
+        super().__init__(
+            output_dir=output_dir,
+            model=model,
+            dataloader_idx_map=dataloader_idx_map,
+            metadata_keys=metadata_keys,
+            include_input=include_input,
+            overwrite=overwrite,
+            save_format=save_format,
+        )
+
+    @override
+    def _unpack_batch(self, batch: TextImageBatch) -> Tuple[list, list | None, dict | None]:  # type: ignore
+        text_batch, _, target_batch, metadata_batch = TextImageBatch(*batch)
+        return text_batch, target_batch, metadata_batch
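A wiring sketch for the new writer (the nn.Identity placeholder and the metadata key are hypothetical; the constructor arguments are exactly those shown above):

from torch import nn

from eva.multimodal.callbacks.writers import TextPredictionWriter

writer = TextPredictionWriter(
    output_dir="./predictions",
    model=nn.Identity(),        # placeholder; normally a multimodal model wrapper
    metadata_keys=["wsi_id"],   # hypothetical metadata key
    save_format="jsonl",
)
# The writer is then registered as a Lightning callback,
# e.g. Trainer(callbacks=[writer]).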
@@ -0,0 +1,5 @@
+"""Data components for multimodal learning."""
+
+from eva.multimodal.data import datasets
+
+__all__ = ["datasets"]
@@ -0,0 +1,5 @@
+"""Multimodal dataloaders API."""
+
+from eva.multimodal.data.dataloaders.collate_fn import text_image_collate
+
+__all__ = ["text_image_collate"]
@@ -0,0 +1,5 @@
+"""Multimodal collate functions API."""
+
+from eva.multimodal.data.dataloaders.collate_fn.text_image import text_image_collate
+
+__all__ = ["text_image_collate"]
@@ -0,0 +1,28 @@
+"""Collate functions for text-image data."""
+
+from typing import List
+
+from torch.utils.data._utils.collate import default_collate
+
+from eva.multimodal.data.datasets.typings import TextImageSample
+from eva.multimodal.models.typings import TextImageBatch
+
+
+def text_image_collate(batch: List[TextImageSample]) -> TextImageBatch:
+    """Collate function for text-image batches."""
+    texts, images, targets, metadata = zip(*batch, strict=False)
+
+    first_sample = batch[0]
+    metadata = None
+    if first_sample.metadata is not None:
+        metadata = {
+            k: [sample.metadata[k] for sample in batch if sample.metadata]
+            for k in first_sample.metadata.keys()
+        }
+
+    return TextImageBatch(
+        text=list(texts),
+        image=list(images),
+        target=default_collate(targets) if targets[0] is not None else None,
+        metadata=metadata,
+    )
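A self-contained sketch of the collate behavior (sample fields follow the TextImageSample typing later in this diff; the images are random tensors):

import torch
from torchvision import tv_tensors

from eva.language.data.messages import UserMessage
from eva.multimodal.data.dataloaders import text_image_collate
from eva.multimodal.data.datasets.typings import TextImageSample

batch = [
    TextImageSample(
        text=[UserMessage(content="Describe this patch.")],
        image=tv_tensors.Image(torch.rand(3, 96, 96)),
        target=i % 2,
        metadata={"index": i},
    )
    for i in range(2)
]

collated = text_image_collate(batch)
# Texts and images stay lists; integer targets are stacked by default_collate;
# metadata becomes {"index": [0, 1]}.
print(len(collated.text), collated.target, collated.metadata)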
@@ -0,0 +1,6 @@
+"""Multimodal datasets API."""
+
+from eva.multimodal.data.datasets.multiple_choice.patch_camelyon import PatchCamelyon
+from eva.multimodal.data.datasets.text_image import TextImageDataset
+
+__all__ = ["TextImageDataset", "PatchCamelyon"]
@@ -0,0 +1,13 @@
+"""Multimodal Dataset base class."""
+
+import abc
+from typing import Generic, TypeVar
+
+from eva.core.data.datasets import base
+
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class MultimodalDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+    """Base dataset class for multimodal tasks."""
@@ -0,0 +1,5 @@
+"""Multiple choice datasets."""
+
+from eva.multimodal.data.datasets.multiple_choice.patch_camelyon import PatchCamelyon
+
+__all__ = ["PatchCamelyon"]
@@ -0,0 +1,80 @@
+"""PatchCamelyon dataset with text prompts for multimodal classification."""
+
+from typing import Any, Dict, Literal
+
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.language.data.messages import MessageSeries, UserMessage
+from eva.multimodal.data.datasets.schemas import TransformsSchema
+from eva.multimodal.data.datasets.text_image import TextImageDataset
+from eva.vision.data import datasets as vision_datasets
+
+
+class PatchCamelyon(TextImageDataset[int], vision_datasets.PatchCamelyon):
+    """PatchCamelyon image classification using a multiple choice text prompt."""
+
+    _default_prompt = (
+        "You are a pathology expert helping pathologists to analyze images of tissue samples.\n"
+        "Question: Does this image show metastatic breast tissue?\n"
+        "Options: A: no, B: yes\n"
+        "Only answer with a single letter without further explanation. "
+        "Please always provide an answer, even if you are not sure.\n"
+        "Answer: "
+    )
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"],
+        download: bool = False,
+        transforms: TransformsSchema | None = None,
+        prompt: str | None = None,
+        max_samples: int | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: The path to the dataset root. This path should contain
+                the uncompressed h5 files and the metadata.
+            split: The dataset split for training, validation, or testing.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
+            prompt: The text prompt to use for classification (multiple choice).
+            max_samples: Maximum number of samples to use. If None, use all samples.
+        """
+        super().__init__(root=root, split=split, download=download, transforms=transforms)
+
+        self.max_samples = max_samples
+        self.prompt = prompt or self._default_prompt
+
+        if self.max_samples is not None:
+            self._expected_length = {split: max_samples}
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"A": 0, "B": 1}
+
+    @override
+    def __len__(self) -> int:
+        return self.max_samples or self._fetch_dataset_length()
+
+    @override
+    def load_text(self, index: int) -> MessageSeries:
+        return [UserMessage(content=self.prompt)]
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        return vision_datasets.PatchCamelyon.load_data(self, index)
+
+    @override
+    def load_target(self, index: int) -> int:
+        return int(vision_datasets.PatchCamelyon.load_target(self, index).item())
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        return vision_datasets.PatchCamelyon.load_metadata(self, index)
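A construction sketch (the root path is hypothetical; the schema's `target` field is inferred from `_apply_transforms` in the TextImageDataset hunk below):

from eva.language.utils.str_to_int_tensor import CastStrToIntTensor
from eva.multimodal.data.datasets import PatchCamelyon
from eva.multimodal.data.datasets.schemas import TransformsSchema

dataset = PatchCamelyon(
    root="data/patch_camelyon",  # hypothetical path to the h5 files
    split="val",
    transforms=TransformsSchema(
        # Cast the label to a tensor; per its docstring, the transform also
        # accepts integer inputs, not only "A"/"B" strings.
        target=CastStrToIntTensor(mapping={"A": 0, "B": 1}),
    ),
    max_samples=8,
)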
@@ -0,0 +1,14 @@
+"""Schema definitions for dataset classes."""
+
+import dataclasses
+from typing import Callable
+
+from eva.language.data.datasets import schemas as language_schemas
+
+
+@dataclasses.dataclass(frozen=True)
+class TransformsSchema(language_schemas.TransformsSchema):
+    """Schema for dataset transforms."""
+
+    image: Callable | None = None
+    """Image transformation"""
@@ -0,0 +1,77 @@
+"""Base classes for text-image datasets."""
+
+import abc
+from typing import Generic
+
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.language.data.datasets.text import TextDataset
+from eva.multimodal.data.datasets.base import MultimodalDataset
+from eva.multimodal.data.datasets.schemas import TransformsSchema
+from eva.multimodal.data.datasets.typings import TargetType, TextImageSample
+
+
+class TextImageDataset(
+    MultimodalDataset[TextImageSample[TargetType]], TextDataset, abc.ABC, Generic[TargetType]
+):
+    """Base dataset class for text-image tasks."""
+
+    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
+        """Initializes the dataset.
+
+        Args:
+            *args: Positional arguments for the base class.
+            transforms: The transforms to apply to the text, image and target when
+                loading the samples.
+            **kwargs: Keyword arguments for the base class.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.transforms = transforms
+
+    @abc.abstractmethod
+    def load_image(self, index: int) -> tv_tensors.Image:
+        """Returns the image content.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The image content.
+        """
+        raise NotImplementedError
+
+    @override
+    def __getitem__(self, index: int) -> TextImageSample[TargetType]:
+        item = TextImageSample(
+            text=self.load_text(index),
+            image=self.load_image(index),
+            target=self.load_target(index),
+            metadata=self.load_metadata(index) or {},
+        )
+        return self._apply_transforms(item)
+
+    @override
+    def _apply_transforms(self, sample: TextImageSample[TargetType]) -> TextImageSample[TargetType]:
+        """Applies the dataset transforms to the text, image and target.
+
+        Args:
+            sample: The sample containing text, image, target and metadata.
+
+        Returns:
+            The transformed sample.
+        """
+        if self.transforms:
+            text = self.transforms.text(sample.text) if self.transforms.text else sample.text
+            image = self.transforms.image(sample.image) if self.transforms.image else sample.image
+            target = (
+                self.transforms.target(sample.target) if self.transforms.target else sample.target
+            )
+            return TextImageSample(
+                text=text,
+                image=image,
+                target=target,
+                metadata=sample.metadata,
+            )
+        return sample
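A minimal concrete subclass sketch (in-memory samples; the method set mirrors the loaders called in __getitem__ above, though other base-class hooks may also need overriding):

import torch
from torchvision import tv_tensors

from eva.language.data.messages import MessageSeries, UserMessage
from eva.multimodal.data.datasets.text_image import TextImageDataset


class ToyTextImageDataset(TextImageDataset[int]):
    """Two-sample dataset for smoke testing."""

    def load_text(self, index: int) -> MessageSeries:
        return [UserMessage(content=f"Describe patch {index}.")]

    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.rand(3, 32, 32))

    def load_target(self, index: int) -> int:
        return index % 2

    def load_metadata(self, index: int) -> dict:
        return {"index": index}

    def __len__(self) -> int:
        return 2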
@@ -0,0 +1,27 @@
+"""Typings for multimodal datasets."""
+
+from typing import Any, Generic, TypeVar
+
+from torchvision import tv_tensors
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextImageSample(NamedTuple, Generic[TargetType]):
+    """Text and image sample with target and metadata."""
+
+    text: MessageSeries
+    """One or multiple conversation messages."""
+
+    image: tv_tensors.Image
+    """Image tensor."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: dict[str, Any] | None
+    """Additional metadata."""
@@ -0,0 +1,8 @@
+"""Multimodal models API."""
+
+from eva.multimodal.models import networks, wrappers
+
+__all__ = [
+    "networks",
+    "wrappers",
+]
@@ -0,0 +1,5 @@
+"""Multimodal Networks API."""
+
+from eva.multimodal.models.modules.vision_language import VisionLanguageModule
+
+__all__ = ["VisionLanguageModule"]
@@ -0,0 +1,55 @@
+"""Model module for vision-language models."""
+
+from typing import Any, List
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import nn
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules import module
+from eva.core.models.modules.utils import batch_postprocess
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class VisionLanguageModule(module.ModelModule):
+    """Model module for vision-language tasks."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the text inference module.
+
+        Args:
+            model: Model instance to use for forward pass.
+            metrics: Metrics schema for evaluation.
+            postprocess: A helper function to post-process model outputs before evaluation.
+        """
+        super().__init__(metrics=metrics, postprocess=postprocess)
+
+        self.model = model
+
+    @override
+    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> List[str]:
+        return self.model(batch)
+
+    @override
+    def validation_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def test_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
+        text, _, targets, metadata = TextImageBatch(*batch)
+        predictions = self.forward(batch)
+        return {
+            "inputs": text,
+            "predictions": predictions,
+            "targets": targets,
+            "metadata": metadata,
+        }
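A smoke-test sketch of the step output (the stub model is hypothetical; TextImageBatch field names follow the collate function above):

import torch
from torch import nn
from torchvision import tv_tensors

from eva.language.data.messages import UserMessage
from eva.multimodal.models.modules import VisionLanguageModule
from eva.multimodal.models.typings import TextImageBatch


class EchoModel(nn.Module):
    """Stub standing in for a wrapped vision-language model."""

    def forward(self, batch: TextImageBatch) -> list:
        return ["B"] * len(batch.text)


module = VisionLanguageModule(model=EchoModel())
batch = TextImageBatch(
    text=[[UserMessage(content="Answer A or B.")]],
    image=[tv_tensors.Image(torch.rand(3, 96, 96))],
    target=torch.tensor([1]),
    metadata=None,
)
output = module.validation_step(batch)
print(output["predictions"])  # ["B"]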
@@ -0,0 +1,14 @@
+"""Multimodal networks API."""
+
+from eva.multimodal.models.networks.alibaba import Qwen25VL7BInstruct
+from eva.multimodal.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
+from eva.multimodal.models.networks.others import PathoR13b
+from eva.multimodal.models.networks.registry import model_registry
+
+__all__ = [
+    "Claude35Sonnet20240620",
+    "Claude37Sonnet20250219",
+    "PathoR13b",
+    "Qwen25VL7BInstruct",
+    "model_registry",
+]
@@ -0,0 +1,39 @@
+"""Models from Alibaba."""
+
+import torch
+
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+@model_registry.register("alibaba/qwen2-5-vl-7b-instruct")
+class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
+    """Qwen2.5-VL 7B Instruct model."""
+
+    def __init__(
+        self,
+        system_prompt: str | None = None,
+        cache_dir: str | None = None,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        """Initialize the model."""
+        super().__init__(
+            model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
+            model_class="Qwen2_5_VLForConditionalGeneration",
+            model_kwargs={
+                "torch_dtype": torch.bfloat16,
+                "trust_remote_code": True,
+                "cache_dir": cache_dir,
+                "attn_implementation": attn_implementation,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+                "do_sample": False,
+            },
+            processor_kwargs={
+                "padding": True,
+                "padding_side": "left",
+                "max_pixels": 451584,  # 672*672
+            },
+            system_prompt=system_prompt,
+        )
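An instantiation sketch (the cache path is hypothetical; first use downloads the Qwen weights from the Hugging Face Hub, and the "eager" fallback avoids the flash-attn dependency required by the default):

from eva.multimodal.models.networks import Qwen25VL7BInstruct

model = Qwen25VL7BInstruct(
    system_prompt="You are a pathology expert.",
    cache_dir="~/.cache/eva",      # hypothetical cache location
    attn_implementation="eager",   # fallback when flash-attn is unavailable
)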
@@ -0,0 +1,11 @@
+"""Multimodal API networks."""
+
+from eva.multimodal.models.networks.api.anthropic import (
+    Claude35Sonnet20240620,
+    Claude37Sonnet20250219,
+)
+
+__all__ = [
+    "Claude35Sonnet20240620",
+    "Claude37Sonnet20250219",
+]