kaiko-eva 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (98) hide show
  1. eva/core/callbacks/config.py +4 -0
  2. eva/core/cli/setup.py +1 -1
  3. eva/core/data/dataloaders/__init__.py +1 -2
  4. eva/core/data/samplers/random.py +17 -10
  5. eva/core/interface/interface.py +21 -0
  6. eva/core/models/modules/module.py +2 -2
  7. eva/core/models/wrappers/base.py +2 -2
  8. eva/core/models/wrappers/from_function.py +3 -3
  9. eva/core/models/wrappers/from_torchhub.py +9 -7
  10. eva/core/models/wrappers/huggingface.py +4 -5
  11. eva/core/models/wrappers/onnx.py +5 -5
  12. eva/core/trainers/trainer.py +2 -0
  13. eva/language/__init__.py +2 -1
  14. eva/language/callbacks/__init__.py +5 -0
  15. eva/language/callbacks/writers/__init__.py +5 -0
  16. eva/language/callbacks/writers/prediction.py +176 -0
  17. eva/language/data/dataloaders/__init__.py +5 -0
  18. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  19. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  20. eva/language/data/datasets/__init__.py +3 -1
  21. eva/language/data/datasets/{language.py → base.py} +1 -1
  22. eva/language/data/datasets/classification/base.py +3 -43
  23. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  24. eva/language/data/datasets/prediction.py +151 -0
  25. eva/language/data/datasets/schemas.py +18 -0
  26. eva/language/data/datasets/text.py +92 -0
  27. eva/language/data/datasets/typings.py +39 -0
  28. eva/language/data/messages.py +60 -0
  29. eva/language/models/__init__.py +15 -11
  30. eva/language/models/modules/__init__.py +2 -2
  31. eva/language/models/modules/language.py +93 -0
  32. eva/language/models/networks/__init__.py +12 -0
  33. eva/language/models/networks/alibaba.py +26 -0
  34. eva/language/models/networks/api/__init__.py +11 -0
  35. eva/language/models/networks/api/anthropic.py +34 -0
  36. eva/language/models/networks/registry.py +5 -0
  37. eva/language/models/typings.py +39 -0
  38. eva/language/models/wrappers/__init__.py +13 -5
  39. eva/language/models/wrappers/base.py +47 -0
  40. eva/language/models/wrappers/from_registry.py +54 -0
  41. eva/language/models/wrappers/huggingface.py +44 -8
  42. eva/language/models/wrappers/litellm.py +81 -46
  43. eva/language/models/wrappers/vllm.py +37 -13
  44. eva/language/utils/__init__.py +2 -1
  45. eva/language/utils/str_to_int_tensor.py +20 -12
  46. eva/language/utils/text/__init__.py +5 -0
  47. eva/language/utils/text/messages.py +113 -0
  48. eva/multimodal/__init__.py +6 -0
  49. eva/multimodal/callbacks/__init__.py +5 -0
  50. eva/multimodal/callbacks/writers/__init__.py +5 -0
  51. eva/multimodal/callbacks/writers/prediction.py +39 -0
  52. eva/multimodal/data/__init__.py +5 -0
  53. eva/multimodal/data/dataloaders/__init__.py +5 -0
  54. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  55. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  56. eva/multimodal/data/datasets/__init__.py +6 -0
  57. eva/multimodal/data/datasets/base.py +13 -0
  58. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  59. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  60. eva/multimodal/data/datasets/schemas.py +14 -0
  61. eva/multimodal/data/datasets/text_image.py +77 -0
  62. eva/multimodal/data/datasets/typings.py +27 -0
  63. eva/multimodal/models/__init__.py +8 -0
  64. eva/multimodal/models/modules/__init__.py +5 -0
  65. eva/multimodal/models/modules/vision_language.py +55 -0
  66. eva/multimodal/models/networks/__init__.py +14 -0
  67. eva/multimodal/models/networks/alibaba.py +39 -0
  68. eva/multimodal/models/networks/api/__init__.py +11 -0
  69. eva/multimodal/models/networks/api/anthropic.py +34 -0
  70. eva/multimodal/models/networks/others.py +47 -0
  71. eva/multimodal/models/networks/registry.py +5 -0
  72. eva/multimodal/models/typings.py +27 -0
  73. eva/multimodal/models/wrappers/__init__.py +13 -0
  74. eva/multimodal/models/wrappers/base.py +47 -0
  75. eva/multimodal/models/wrappers/from_registry.py +54 -0
  76. eva/multimodal/models/wrappers/huggingface.py +180 -0
  77. eva/multimodal/models/wrappers/litellm.py +56 -0
  78. eva/multimodal/utils/__init__.py +1 -0
  79. eva/multimodal/utils/image/__init__.py +5 -0
  80. eva/multimodal/utils/image/encode.py +28 -0
  81. eva/multimodal/utils/text/__init__.py +1 -0
  82. eva/multimodal/utils/text/messages.py +79 -0
  83. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  84. eva/vision/data/transforms/__init__.py +2 -1
  85. eva/vision/data/transforms/spatial/__init__.py +2 -1
  86. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  87. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  88. eva/vision/data/transforms/spatial/resize.py +62 -0
  89. eva/vision/models/wrappers/from_registry.py +6 -5
  90. eva/vision/models/wrappers/from_timm.py +6 -4
  91. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
  92. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +95 -38
  93. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  94. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  95. eva/language/models/modules/text.py +0 -85
  96. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
  97. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
  98. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -10,11 +10,20 @@ from loguru import logger
10
10
  from typing_extensions import override
11
11
 
12
12
  from eva.language.data.datasets.classification import base
13
+ from eva.language.data.messages import MessageSeries, UserMessage
13
14
 
14
15
 
15
16
  class PubMedQA(base.TextClassification):
16
17
  """Dataset class for PubMedQA question answering task."""
17
18
 
19
+ _expected_dataset_lengths: Dict[str | None, int] = {
20
+ "train": 450,
21
+ "val": 50,
22
+ "test": 500,
23
+ None: 500,
24
+ }
25
+ """Expected dataset lengths for the splits and complete dataset."""
26
+
18
27
  _license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
19
28
  """Dataset license."""
20
29
 
@@ -52,7 +61,14 @@ class PubMedQA(base.TextClassification):
52
61
  """
53
62
  dataset_name = "bigbio/pubmed_qa"
54
63
  config_name = "pubmed_qa_labeled_fold0_source"
55
- split = (self._split or "train+test+validation") if self._split != "val" else "validation"
64
+
65
+ match self._split:
66
+ case "val":
67
+ split = "validation"
68
+ case None:
69
+ split = "train+test+validation"
70
+ case _:
71
+ split = self._split
56
72
 
57
73
  if self._download:
58
74
  logger.info("Downloading dataset from HuggingFace Hub")
@@ -88,7 +104,7 @@ class PubMedQA(base.TextClassification):
88
104
  dataset_path = None
89
105
 
90
106
  if self._root:
91
- dataset_path = self._root
107
+ dataset_path = os.path.join(self._root, self._split) if self._split else self._root
92
108
  os.makedirs(self._root, exist_ok=True)
93
109
 
94
110
  try:
@@ -103,6 +119,15 @@ class PubMedQA(base.TextClassification):
103
119
  except Exception as e:
104
120
  raise RuntimeError(f"Failed to prepare dataset: {e}") from e
105
121
 
122
+ @override
123
+ def validate(self) -> None:
124
+ if len(self) != self._expected_dataset_lengths[self._split]:
125
+ raise ValueError(
126
+ f"Dataset length mismatch for split '{self._split}': "
127
+ f"expected {self._expected_dataset_lengths[self._split]}, "
128
+ f"but got {len(self)}"
129
+ )
130
+
106
131
  @property
107
132
  @override
108
133
  def classes(self) -> List[str]:
@@ -114,11 +139,18 @@ class PubMedQA(base.TextClassification):
114
139
  return {"no": 0, "yes": 1, "maybe": 2}
115
140
 
116
141
  @override
117
- def load_text(self, index: int) -> str:
142
+ def load_text(self, index: int) -> MessageSeries:
118
143
  if index < 0 or index >= len(self.dataset):
119
144
  raise IndexError(f"Index {index} out of range for dataset of size {len(self.dataset)}")
120
145
  sample = dict(self.dataset[index])
121
- return f"Question: {sample['QUESTION']}\nContext: " + " ".join(sample["CONTEXTS"])
146
+ return [
147
+ UserMessage(
148
+ content=f"Question: {sample['QUESTION']}\nContext: "
149
+ + " ".join(sample["CONTEXTS"])
150
+ + "\nInstruction: Carefully read the question and the provided context. "
151
+ + "Answer with one word: 'yes', 'no', or 'maybe'. Answer: "
152
+ )
153
+ ]
122
154
 
123
155
  @override
124
156
  def load_target(self, index: int) -> torch.Tensor:
@@ -0,0 +1,151 @@
1
+ """Dataset class for loading pre-generated text predictions."""
2
+
3
+ import abc
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Generic, Literal
6
+
7
+ import pandas as pd
8
+ from typing_extensions import override
9
+
10
+ from eva.language.data.datasets.base import LanguageDataset
11
+ from eva.language.data.datasets.schemas import TransformsSchema
12
+ from eva.language.data.datasets.typings import PredictionSample, TargetType
13
+ from eva.language.data.messages import MessageSeries, UserMessage
14
+ from eva.language.utils.text import messages as message_utils
15
+
16
+
17
+ class TextPredictionDataset(
18
+ LanguageDataset[PredictionSample[TargetType]], abc.ABC, Generic[TargetType]
19
+ ):
20
+ """Dataset class for loading pre-generated text predictions."""
21
+
22
+ def __init__(
23
+ self,
24
+ path: str,
25
+ prediction_column: str = "prediction",
26
+ target_column: str = "target",
27
+ text_column: str | None = None,
28
+ metadata_columns: list[str] | None = None,
29
+ split: Literal["train", "val", "test"] | None = None,
30
+ transforms: TransformsSchema | None = None,
31
+ ):
32
+ """Initialize the dataset.
33
+
34
+ Args:
35
+ path: The path to the manifest file holding the predictions & targets.
36
+ prediction_column: The name of the prediction column.
37
+ target_column: The name of the label column.
38
+ text_column: The name of the column with the text inputs that were used
39
+ to generate the predictions. If the text column contains chat message
40
+ json format ([{"role": ..., "content": ...}]), it will be deserialized into
41
+ a list of Message objects. Otherwise, the content is interpreted as a
42
+ single user message.
43
+ metadata_columns: List of column names to include in metadata.
44
+ split: The dataset split to use (train, val, test). If not specified,
45
+ the entire dataset will be used.
46
+ transforms: The transforms to apply to the text and target when
47
+ loading the samples.
48
+ """
49
+ super().__init__()
50
+
51
+ self.path = path
52
+ self.prediction_column = prediction_column
53
+ self.target_column = target_column
54
+ self.text_column = text_column
55
+ self.metadata_columns = metadata_columns
56
+ self.split = split
57
+ self.transforms = transforms
58
+
59
+ self._data: pd.DataFrame
60
+
61
+ @override
62
+ def __len__(self) -> int:
63
+ return len(self._data)
64
+
65
+ @override
66
+ def __getitem__(self, index: int) -> PredictionSample[TargetType]:
67
+ item = PredictionSample(
68
+ prediction=self.load_prediction(index),
69
+ target=self.load_target(index),
70
+ text=self.load_text(index),
71
+ metadata=self.load_metadata(index) or {},
72
+ )
73
+ return self._apply_transforms(item)
74
+
75
+ @override
76
+ def configure(self) -> None:
77
+ extension = Path(self.path).suffix
78
+
79
+ match extension:
80
+ case ".jsonl":
81
+ self._data = pd.read_json(self.path, lines=True)
82
+ case ".csv":
83
+ self._data = pd.read_csv(self.path)
84
+ case ".parquet":
85
+ self._data = pd.read_parquet(self.path)
86
+ case _:
87
+ raise ValueError(f"Unsupported file extension: {extension}")
88
+
89
+ if self.split is not None:
90
+ self._data = self._data[self._data["split"] == self.split].reset_index(drop=True) # type: ignore
91
+
92
+ @override
93
+ def validate(self) -> None:
94
+ if self.prediction_column not in self._data.columns:
95
+ raise ValueError(f"Prediction column '{self.prediction_column}' not found.")
96
+ if self.target_column not in self._data.columns:
97
+ raise ValueError(f"Label column '{self.target_column}' not found.")
98
+ if self.metadata_columns:
99
+ missing_columns = set(self.metadata_columns) - set(self._data.columns)
100
+ if missing_columns:
101
+ raise ValueError(f"Metadata columns {missing_columns} not found.")
102
+
103
+ def load_prediction(self, index: int) -> TargetType:
104
+ """Returns the prediction for the given index."""
105
+ return self._data.iloc[index][self.prediction_column]
106
+
107
+ def load_target(self, index: int) -> TargetType:
108
+ """Returns the target for the given index."""
109
+ return self._data.iloc[index][self.target_column]
110
+
111
+ def load_text(self, index: int) -> MessageSeries | None:
112
+ """Returns the text for the given index."""
113
+ if self.text_column is None:
114
+ return None
115
+
116
+ text = self._data.iloc[index][self.text_column]
117
+
118
+ try:
119
+ return message_utils.deserialize(self._data.iloc[index][self.text_column])
120
+ except Exception:
121
+ return [UserMessage(content=text)]
122
+
123
+ def load_metadata(self, index: int) -> Dict[str, Any] | None:
124
+ """Returns the metadata for the given index."""
125
+ if self.metadata_columns is None:
126
+ return None
127
+
128
+ row = self._data.iloc[index]
129
+ return {col: row[col] for col in self.metadata_columns}
130
+
131
+ def _apply_transforms(
132
+ self, sample: PredictionSample[TargetType]
133
+ ) -> PredictionSample[TargetType]:
134
+ """Applies the dataset transforms to the prediction and target."""
135
+ if self.transforms:
136
+ text = self.transforms.text(sample.text) if self.transforms.text else sample.text
137
+ prediction = (
138
+ self.transforms.prediction(sample.prediction)
139
+ if self.transforms.prediction
140
+ else sample.prediction
141
+ )
142
+ target = (
143
+ self.transforms.target(sample.target) if self.transforms.target else sample.target
144
+ )
145
+ return PredictionSample(
146
+ prediction=prediction,
147
+ target=target,
148
+ text=text,
149
+ metadata=sample.metadata,
150
+ )
151
+ return sample
@@ -0,0 +1,18 @@
1
+ """Schema definitions for dataset classes."""
2
+
3
+ import dataclasses
4
+ from typing import Callable
5
+
6
+
7
+ @dataclasses.dataclass(frozen=True)
8
+ class TransformsSchema:
9
+ """Schema for dataset transforms."""
10
+
11
+ text: Callable | None = None
12
+ """Text transformation"""
13
+
14
+ target: Callable | None = None
15
+ """Target transformation"""
16
+
17
+ prediction: Callable | None = None
18
+ """Prediction transformation"""
@@ -0,0 +1,92 @@
1
+ """Base classes for text datasets."""
2
+
3
+ import abc
4
+ from typing import Any, Dict, Generic
5
+
6
+ from typing_extensions import override
7
+
8
+ from eva.language.data.datasets.base import LanguageDataset
9
+ from eva.language.data.datasets.schemas import TransformsSchema
10
+ from eva.language.data.datasets.typings import TargetType, TextSample
11
+ from eva.language.data.messages import MessageSeries
12
+
13
+
14
+ class TextDataset(LanguageDataset[TextSample[TargetType]], abc.ABC, Generic[TargetType]):
15
+ """Base dataset class for text-based tasks."""
16
+
17
+ def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
18
+ """Initializes the dataset.
19
+
20
+ Args:
21
+ *args: Positional arguments for the base class.
22
+ transforms: The transforms to apply to the text and target when
23
+ loading the samples.
24
+ **kwargs: Keyword arguments for the base class.
25
+ """
26
+ super().__init__(*args, **kwargs)
27
+
28
+ self.transforms = transforms
29
+
30
+ def load_metadata(self, index: int) -> Dict[str, Any] | None:
31
+ """Returns the dataset metadata.
32
+
33
+ Args:
34
+ index: The index of the data sample.
35
+
36
+ Returns:
37
+ The sample metadata.
38
+ """
39
+
40
+ @abc.abstractmethod
41
+ def load_text(self, index: int) -> MessageSeries:
42
+ """Returns the text content.
43
+
44
+ Args:
45
+ index: The index of the data sample.
46
+
47
+ Returns:
48
+ The text content.
49
+ """
50
+ raise NotImplementedError
51
+
52
+ @abc.abstractmethod
53
+ def load_target(self, index: int) -> TargetType:
54
+ """Returns the target label.
55
+
56
+ Args:
57
+ index: The index of the data sample.
58
+
59
+ Returns:
60
+ The target label.
61
+ """
62
+ raise NotImplementedError
63
+
64
+ @override
65
+ def __getitem__(self, index: int) -> TextSample[TargetType]:
66
+ item = TextSample(
67
+ text=self.load_text(index),
68
+ target=self.load_target(index),
69
+ metadata=self.load_metadata(index) or {},
70
+ )
71
+ return self._apply_transforms(item)
72
+
73
+ def _apply_transforms(self, sample: TextSample[TargetType]) -> TextSample[TargetType]:
74
+ """Applies the dataset transforms to the text and target.
75
+
76
+ Args:
77
+ sample: The text sample.
78
+
79
+ Returns:
80
+ The transformed sample.
81
+ """
82
+ if self.transforms:
83
+ text = self.transforms.text(sample.text) if self.transforms.text else sample.text
84
+ target = (
85
+ self.transforms.target(sample.target) if self.transforms.target else sample.target
86
+ )
87
+ return TextSample(
88
+ text=text,
89
+ target=target,
90
+ metadata=sample.metadata,
91
+ )
92
+ return sample
@@ -0,0 +1,39 @@
1
+ """Typings for language datasets."""
2
+
3
+ from typing import Any, Generic, TypeVar
4
+
5
+ from typing_extensions import NamedTuple
6
+
7
+ from eva.language.data.messages import MessageSeries
8
+
9
+ TargetType = TypeVar("TargetType")
10
+ """The target data type."""
11
+
12
+
13
+ class TextSample(NamedTuple, Generic[TargetType]):
14
+ """Text sample with target and metadata."""
15
+
16
+ text: MessageSeries
17
+ """One or multiple conversation messages."""
18
+
19
+ target: TargetType | None
20
+ """Target data."""
21
+
22
+ metadata: dict[str, Any] | None
23
+ """Additional metadata."""
24
+
25
+
26
+ class PredictionSample(NamedTuple, Generic[TargetType]):
27
+ """Prediction sample with target, input text, and metadata."""
28
+
29
+ prediction: TargetType
30
+ """Prediction data."""
31
+
32
+ target: TargetType
33
+ """Target data."""
34
+
35
+ text: MessageSeries | None
36
+ """Conversation messages that were used as input."""
37
+
38
+ metadata: dict[str, Any] | None
39
+ """Additional metadata."""
@@ -0,0 +1,60 @@
1
+ """Types and classes for conversation messages in a multimodal context."""
2
+
3
+ import dataclasses
4
+ import enum
5
+ from typing import Any, Dict, List
6
+
7
+
8
+ class Role(str, enum.Enum):
9
+ """Roles for messages in a conversation."""
10
+
11
+ USER = "user"
12
+ ASSISTANT = "assistant"
13
+ SYSTEM = "system"
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class Message:
18
+ """Base class for a message in a conversation."""
19
+
20
+ content: str
21
+ role: str
22
+
23
+ def to_dict(self) -> Dict[str, Any]:
24
+ """Convert the message to a dictionary."""
25
+ return dataclasses.asdict(self)
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class UserMessage(Message):
30
+ """User message in a conversation."""
31
+
32
+ role: str = Role.USER
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class AssistantMessage(Message):
37
+ """Assistant message in a conversation."""
38
+
39
+ role: str = Role.ASSISTANT
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class SystemMessage(Message):
44
+ """System message in a conversation."""
45
+
46
+ role: str = Role.SYSTEM
47
+
48
+
49
+ @dataclasses.dataclass
50
+ class ModelSystemMessage(SystemMessage):
51
+ """System message for model-specific instructions."""
52
+
53
+
54
+ @dataclasses.dataclass
55
+ class TaskSystemMessage(SystemMessage):
56
+ """System message for task-specific instructions."""
57
+
58
+
59
+ MessageSeries = List[Message]
60
+ """A series of conversation messages, can contain a mix of system, user, and AI messages."""
@@ -1,25 +1,29 @@
1
1
  """Language Models API."""
2
2
 
3
- from eva.language.models import modules, wrappers
4
- from eva.language.models.modules import TextModule
5
- from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel
3
+ from eva.language.models import modules, networks, wrappers
4
+ from eva.language.models.modules import LanguageModule, OfflineLanguageModule
5
+ from eva.language.models.wrappers import HuggingFaceModel, LiteLLMModel
6
6
 
7
7
  try:
8
- from eva.language.models.wrappers import VLLMTextModel
8
+ from eva.language.models.wrappers import VllmModel
9
9
 
10
10
  __all__ = [
11
11
  "modules",
12
12
  "wrappers",
13
- "TextModule",
14
- "HuggingFaceTextModel",
15
- "LiteLLMTextModel",
16
- "VLLMTextModel",
13
+ "networks",
14
+ "HuggingFaceModel",
15
+ "LiteLLMModel",
16
+ "VllmModel",
17
+ "LanguageModule",
18
+ "OfflineLanguageModule",
17
19
  ]
18
20
  except ImportError:
19
21
  __all__ = [
20
22
  "modules",
21
23
  "wrappers",
22
- "TextModule",
23
- "HuggingFaceTextModel",
24
- "LiteLLMTextModel",
24
+ "networks",
25
+ "HuggingFaceModel",
26
+ "LiteLLMModel",
27
+ "LanguageModule",
28
+ "OfflineLanguageModule",
25
29
  ]
@@ -1,5 +1,5 @@
1
1
  """Language Networks API."""
2
2
 
3
- from eva.language.models.modules.text import TextModule
3
+ from eva.language.models.modules.language import LanguageModule, OfflineLanguageModule
4
4
 
5
- __all__ = ["TextModule"]
5
+ __all__ = ["LanguageModule", "OfflineLanguageModule"]
@@ -0,0 +1,93 @@
1
+ """Model module for language models."""
2
+
3
+ from typing import Any, List
4
+
5
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
6
+ from torch import nn
7
+ from typing_extensions import override
8
+
9
+ from eva.core.metrics import structs as metrics_lib
10
+ from eva.core.models.modules import module
11
+ from eva.core.models.modules.utils import batch_postprocess
12
+ from eva.language.models.typings import PredictionBatch, TextBatch
13
+
14
+
15
+ class LanguageModule(module.ModelModule):
16
+ """Model module for language tasks."""
17
+
18
+ def __init__(
19
+ self,
20
+ model: nn.Module,
21
+ metrics: metrics_lib.MetricsSchema | None = None,
22
+ postprocess: batch_postprocess.BatchPostProcess | None = None,
23
+ ) -> None:
24
+ """Initializes the text inference module.
25
+
26
+ Args:
27
+ model: Model instance to use for forward pass.
28
+ metrics: Metrics schema for evaluation.
29
+ postprocess: A helper function to post-process model outputs before evaluation.
30
+ """
31
+ super().__init__(metrics=metrics, postprocess=postprocess)
32
+
33
+ self.model = model
34
+
35
+ @override
36
+ def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> List[str]:
37
+ return self.model(batch)
38
+
39
+ @override
40
+ def validation_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
41
+ return self._batch_step(batch)
42
+
43
+ @override
44
+ def test_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
45
+ return self._batch_step(batch)
46
+
47
+ def _batch_step(self, batch: TextBatch) -> STEP_OUTPUT:
48
+ text, targets, metadata = TextBatch(*batch)
49
+ predictions = self.forward(batch)
50
+ return {
51
+ "inputs": text,
52
+ "predictions": predictions,
53
+ "targets": targets,
54
+ "metadata": metadata,
55
+ }
56
+
57
+
58
+ class OfflineLanguageModule(module.ModelModule):
59
+ """Model module for offline language tasks."""
60
+
61
+ def __init__(
62
+ self,
63
+ metrics: metrics_lib.MetricsSchema | None = None,
64
+ postprocess: batch_postprocess.BatchPostProcess | None = None,
65
+ ) -> None:
66
+ """Initializes the text inference module.
67
+
68
+ Args:
69
+ metrics: Metrics schema for evaluation.
70
+ postprocess: A helper function to post-process model outputs before evaluation.
71
+ """
72
+ super().__init__(metrics=metrics, postprocess=postprocess)
73
+
74
+ @override
75
+ def forward(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> PredictionBatch:
76
+ return batch
77
+
78
+ @override
79
+ def validation_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
80
+ return self._batch_step(batch)
81
+
82
+ @override
83
+ def test_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
84
+ return self._batch_step(batch)
85
+
86
+ def _batch_step(self, batch: PredictionBatch) -> STEP_OUTPUT:
87
+ predictions, targets, text, metadata = PredictionBatch(*batch)
88
+ return {
89
+ "inputs": text,
90
+ "predictions": predictions,
91
+ "targets": targets,
92
+ "metadata": metadata,
93
+ }
@@ -0,0 +1,12 @@
1
+ """Language networks API."""
2
+
3
+ from eva.language.models.networks.alibaba import Qwen205BInstruct
4
+ from eva.language.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
5
+ from eva.language.models.networks.registry import model_registry
6
+
7
+ __all__ = [
8
+ "Claude35Sonnet20240620",
9
+ "Claude37Sonnet20250219",
10
+ "Qwen205BInstruct",
11
+ "model_registry",
12
+ ]
@@ -0,0 +1,26 @@
1
+ """Models from Alibaba."""
2
+
3
+ import torch
4
+
5
+ from eva.language.models import wrappers
6
+ from eva.language.models.networks.registry import model_registry
7
+
8
+
9
+ @model_registry.register("alibaba/qwen2-0-5b-instruct")
10
+ class Qwen205BInstruct(wrappers.HuggingFaceModel):
11
+ """Qwen2 0.5B Instruct model."""
12
+
13
+ def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
14
+ """Initialize the model."""
15
+ super().__init__(
16
+ model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
17
+ model_kwargs={
18
+ "torch_dtype": torch.bfloat16,
19
+ "cache_dir": cache_dir,
20
+ },
21
+ generation_kwargs={
22
+ "max_new_tokens": 512,
23
+ },
24
+ system_prompt=system_prompt,
25
+ chat_mode=True,
26
+ )
@@ -0,0 +1,11 @@
1
+ """Language API networks."""
2
+
3
+ from eva.language.models.networks.api.anthropic import (
4
+ Claude35Sonnet20240620,
5
+ Claude37Sonnet20250219,
6
+ )
7
+
8
+ __all__ = [
9
+ "Claude35Sonnet20240620",
10
+ "Claude37Sonnet20250219",
11
+ ]
@@ -0,0 +1,34 @@
1
+ """Models from Anthropic."""
2
+
3
+ import os
4
+
5
+ from eva.language.models import wrappers
6
+ from eva.language.models.networks.registry import model_registry
7
+
8
+
9
+ class _Claude(wrappers.LiteLLMModel):
10
+ """Base class for Claude models."""
11
+
12
+ def __init__(self, model_name: str, system_prompt: str | None = None):
13
+ if not os.getenv("ANTHROPIC_API_KEY"):
14
+ raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
15
+
16
+ super().__init__(model_name=model_name, system_prompt=system_prompt)
17
+
18
+
19
+ @model_registry.register("anthropic/claude-3-5-sonnet-20240620")
20
+ class Claude35Sonnet20240620(_Claude):
21
+ """Claude 3.5 Sonnet (June 2024) model."""
22
+
23
+ def __init__(self, system_prompt: str | None = None):
24
+ """Initialize the model."""
25
+ super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
26
+
27
+
28
+ @model_registry.register("anthropic/claude-3-7-sonnet-20250219")
29
+ class Claude37Sonnet20250219(_Claude):
30
+ """Claude 3.7 Sonnet (February 2025) model."""
31
+
32
+ def __init__(self, system_prompt: str | None = None):
33
+ """Initialize the model."""
34
+ super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
@@ -0,0 +1,5 @@
1
+ """Language Model Registry."""
2
+
3
+ from eva.core.utils.registry import Registry
4
+
5
+ model_registry = Registry()