kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Base class for text datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Dict, Generic
|
|
5
|
+
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.language.data.datasets.base import LanguageDataset
|
|
9
|
+
from eva.language.data.datasets.schemas import TransformsSchema
|
|
10
|
+
from eva.language.data.datasets.typings import TargetType, TextSample
|
|
11
|
+
from eva.language.data.messages import MessageSeries
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TextDataset(LanguageDataset[TextSample[TargetType]], abc.ABC, Generic[TargetType]):
    """Base dataset class for text-based tasks.

    Subclasses implement `load_text` and `load_target` and may override
    `load_metadata`. `__getitem__` combines them into a `TextSample` and then
    applies the optional `transforms`.
    """

    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
        """Initializes the dataset.

        Args:
            *args: Positional arguments for the base class.
            transforms: The transforms to apply to the text and target when
                loading the samples.
            **kwargs: Keyword arguments for the base class.
        """
        super().__init__(*args, **kwargs)

        self.transforms = transforms

    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        """Returns the dataset metadata.

        The default implementation provides no metadata and returns `None`;
        subclasses can override it to attach per-sample information.

        Args:
            index: The index of the data sample.

        Returns:
            The sample metadata, or `None` if there is none.
        """
        return None

    @abc.abstractmethod
    def load_text(self, index: int) -> MessageSeries:
        """Returns the text content.

        Args:
            index: The index of the data sample.

        Returns:
            The text content.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load_target(self, index: int) -> TargetType:
        """Returns the target label.

        Args:
            index: The index of the data sample.

        Returns:
            The target label.
        """
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> TextSample[TargetType]:
        """Loads the sample at `index` and applies the dataset transforms.

        Args:
            index: The index of the data sample.

        Returns:
            The (possibly transformed) text sample. Falsy metadata (e.g. the
            default `None`) is replaced by an empty dict.
        """
        item = TextSample(
            text=self.load_text(index),
            target=self.load_target(index),
            metadata=self.load_metadata(index) or {},
        )
        return self._apply_transforms(item)

    def _apply_transforms(self, sample: TextSample[TargetType]) -> TextSample[TargetType]:
        """Applies the dataset transforms to the text and target.

        Metadata is passed through unchanged. Each of the text/target
        transforms is optional; a missing one leaves that field as-is.

        Args:
            sample: The text sample.

        Returns:
            The transformed sample, or `sample` itself if no transforms are set.
        """
        if not self.transforms:
            return sample

        text = self.transforms.text(sample.text) if self.transforms.text else sample.text
        target = (
            self.transforms.target(sample.target) if self.transforms.target else sample.target
        )
        return TextSample(
            text=text,
            target=target,
            metadata=sample.metadata,
        )
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Typings for language datasets."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from typing_extensions import NamedTuple
|
|
6
|
+
|
|
7
|
+
from eva.language.data.messages import MessageSeries
|
|
8
|
+
|
|
9
|
+
TargetType = TypeVar("TargetType")
"""The target data type."""


class TextSample(NamedTuple, Generic[TargetType]):
    """A single text sample: the input conversation, an optional target and optional metadata."""

    # The input conversation, made of one or more messages.
    text: MessageSeries

    # The target associated with the input, if any.
    target: TargetType | None

    # Extra per-sample information, if any.
    metadata: dict[str, Any] | None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PredictionSample(NamedTuple, Generic[TargetType]):
    """A stored model prediction together with its target, input text and metadata."""

    prediction: TargetType
    """Prediction data."""

    target: TargetType
    """Target data."""

    text: MessageSeries | None
    """Conversation messages that were used as input."""

    metadata: dict[str, Any] | None
    """Additional metadata."""
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Types and classes for conversation messages in a multimodal context."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import enum
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Role(str, enum.Enum):
    """Roles for messages in a conversation.

    Members compare equal to their plain string values (``Role.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

    def __str__(self) -> str:
        """Returns the raw role value (e.g. ``"user"``) rather than ``"Role.USER"``.

        For a ``str``-mixin Enum, the default ``str()`` yields ``"Role.USER"``.
        Since Python 3.11, f-strings and ``format()`` yield it too. Without
        this override, that text would leak into any prompt or chat template
        that renders the role.
        """
        return self.value
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclasses.dataclass
class Message:
    """A single turn in a conversation.

    Attributes:
        content: The text of the message.
        role: The author of the message (e.g. ``"user"``, ``"assistant"``, ``"system"``).
    """

    content: str
    role: str

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the message into a ``{"content": ..., "role": ...}`` dictionary."""
        as_dict = dataclasses.asdict(self)
        return as_dict
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclasses.dataclass
class UserMessage(Message):
    """Conversation message authored by the user; ``role`` defaults to ``Role.USER``."""

    role: str = dataclasses.field(default=Role.USER)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclasses.dataclass
class AssistantMessage(Message):
    """Conversation message authored by the assistant; ``role`` defaults to ``Role.ASSISTANT``."""

    role: str = dataclasses.field(default=Role.ASSISTANT)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclasses.dataclass
class SystemMessage(Message):
    """System-level instruction message; ``role`` defaults to ``Role.SYSTEM``."""

    role: str = dataclasses.field(default=Role.SYSTEM)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclasses.dataclass
class ModelSystemMessage(SystemMessage):
    """System message that carries model-specific instructions (e.g. a model's system prompt).

    It inherits ``content`` and the ``Role.SYSTEM`` default role from ``SystemMessage``.
    """
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclasses.dataclass
class TaskSystemMessage(SystemMessage):
    """System message that carries task-specific instructions.

    It inherits ``content`` and the ``Role.SYSTEM`` default role from ``SystemMessage``.
    """
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Type alias for an ordered list of conversation messages.
MessageSeries = List[Message]
"""A series of conversation messages; it can mix system, user, and assistant messages."""
|
eva/language/models/__init__.py
CHANGED
|
@@ -1,25 +1,29 @@
|
|
|
1
1
|
"""Language Models API."""
|
|
2
2
|
|
|
3
|
-
from eva.language.models import modules, wrappers
|
|
4
|
-
from eva.language.models.modules import
|
|
5
|
-
from eva.language.models.wrappers import
|
|
3
|
+
from eva.language.models import modules, networks, wrappers
|
|
4
|
+
from eva.language.models.modules import LanguageModule, OfflineLanguageModule
|
|
5
|
+
from eva.language.models.wrappers import HuggingFaceModel, LiteLLMModel
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
|
-
from eva.language.models.wrappers import
|
|
8
|
+
from eva.language.models.wrappers import VllmModel
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"modules",
|
|
12
12
|
"wrappers",
|
|
13
|
-
"
|
|
14
|
-
"
|
|
15
|
-
"
|
|
16
|
-
"
|
|
13
|
+
"networks",
|
|
14
|
+
"HuggingFaceModel",
|
|
15
|
+
"LiteLLMModel",
|
|
16
|
+
"VllmModel",
|
|
17
|
+
"LanguageModule",
|
|
18
|
+
"OfflineLanguageModule",
|
|
17
19
|
]
|
|
18
20
|
except ImportError:
|
|
19
21
|
__all__ = [
|
|
20
22
|
"modules",
|
|
21
23
|
"wrappers",
|
|
22
|
-
"
|
|
23
|
-
"
|
|
24
|
-
"
|
|
24
|
+
"networks",
|
|
25
|
+
"HuggingFaceModel",
|
|
26
|
+
"LiteLLMModel",
|
|
27
|
+
"LanguageModule",
|
|
28
|
+
"OfflineLanguageModule",
|
|
25
29
|
]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Model module for language models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
6
|
+
from torch import nn
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.metrics import structs as metrics_lib
|
|
10
|
+
from eva.core.models.modules import module
|
|
11
|
+
from eva.core.models.modules.utils import batch_postprocess
|
|
12
|
+
from eva.language.models.typings import ModelOutput, PredictionBatch, TextBatch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LanguageModule(module.ModelModule):
    """Model module that runs a language model on text batches for evaluation."""

    def __init__(
        self,
        model: nn.Module,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the text inference module.

        Args:
            model: The model that is called on each batch in the forward pass.
            metrics: Metrics schema used for evaluation.
            postprocess: Optional helper that post-processes the model outputs
                before they are evaluated.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)
        self.model = model

    @override
    def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> ModelOutput:
        return self.model(batch)

    @override
    def validation_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: TextBatch) -> STEP_OUTPUT:
        """Runs the model on `batch` and packs inputs, predictions, targets and metadata.

        The model's `generated_text` entry becomes `predictions`. All remaining
        entries of the model output are merged in after the standard keys.
        """
        inputs, targets, metadata = TextBatch(*batch)
        model_output = self.forward(batch)
        predictions = model_output.pop("generated_text")  # type: ignore
        step_output = {
            "inputs": inputs,
            "predictions": predictions,
            "targets": targets,
            "metadata": metadata,
        }
        return step_output | model_output
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class OfflineLanguageModule(module.ModelModule):
    """Model module that evaluates precomputed predictions; no model is run."""

    def __init__(
        self,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the text inference module.

        Args:
            metrics: Metrics schema used for evaluation.
            postprocess: Optional helper that post-processes the predictions
                before they are evaluated.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

    @override
    def forward(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> PredictionBatch:
        # Identity: the batch already contains the predictions.
        return batch

    @override
    def validation_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: PredictionBatch) -> STEP_OUTPUT:
        """Repackages a prediction batch into the step-output dictionary."""
        unpacked = PredictionBatch(*batch)
        return {
            "inputs": unpacked.text,
            "predictions": unpacked.prediction,
            "targets": unpacked.target,
            "metadata": unpacked.metadata,
        }
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Language networks API."""
|
|
2
|
+
|
|
3
|
+
from eva.language.models.networks.alibaba import Qwen205BInstruct
|
|
4
|
+
from eva.language.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
|
|
5
|
+
from eva.language.models.networks.registry import model_registry
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Claude35Sonnet20240620",
|
|
9
|
+
"Claude37Sonnet20250219",
|
|
10
|
+
"Qwen205BInstruct",
|
|
11
|
+
"model_registry",
|
|
12
|
+
]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Models from Alibaba."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from eva.language.models import wrappers
|
|
6
|
+
from eva.language.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@model_registry.register("alibaba/qwen2-0-5b-instruct")
class Qwen205BInstruct(wrappers.HuggingFaceModel):
    """Qwen2 0.5B Instruct model (`Qwen/Qwen2-0.5B-Instruct`), loaded in bfloat16 and run in chat mode."""

    def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt for the model.
            cache_dir: Optional cache directory, passed through `model_kwargs`.
        """
        load_kwargs = {"torch_dtype": torch.bfloat16, "cache_dir": cache_dir}
        generate_kwargs = {"max_new_tokens": 512}
        super().__init__(
            model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
            model_kwargs=load_kwargs,
            generation_kwargs=generate_kwargs,
            system_prompt=system_prompt,
            chat_mode=True,
        )
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Models from Anthropic."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from eva.language.models import wrappers
|
|
6
|
+
from eva.language.models.networks.registry import model_registry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _Claude(wrappers.LiteLLMModel):
    """Shared base for Anthropic Claude models wrapped through `LiteLLMModel`."""

    def __init__(self, model_name: str, system_prompt: str | None = None):
        """Initializes a Claude model.

        Args:
            model_name: The Anthropic model identifier.
            system_prompt: Optional system prompt for the model.

        Raises:
            ValueError: If the `ANTHROPIC_API_KEY` environment variable is unset or empty.
        """
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")

        super().__init__(model_name=model_name, system_prompt=system_prompt)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
class Claude35Sonnet20240620(_Claude):
    """Anthropic Claude 3.5 Sonnet, snapshot of 2024-06-20."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt for the model.
        """
        model_id = "claude-3-5-sonnet-20240620"
        super().__init__(model_name=model_id, system_prompt=system_prompt)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
class Claude37Sonnet20250219(_Claude):
    """Anthropic Claude 3.7 Sonnet, snapshot of 2025-02-19."""

    def __init__(self, system_prompt: str | None = None):
        """Initialize the model.

        Args:
            system_prompt: Optional system prompt for the model.
        """
        model_id = "claude-3-7-sonnet-20250219"
        super().__init__(model_name=model_id, system_prompt=system_prompt)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Type definitions for language models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Generic, List, TypedDict, TypeVar
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from typing_extensions import NamedTuple, NotRequired
|
|
7
|
+
|
|
8
|
+
from eva.language.data.messages import MessageSeries
|
|
9
|
+
|
|
10
|
+
TargetType = TypeVar("TargetType")
"""The target data type."""


class TextBatch(NamedTuple, Generic[TargetType]):
    """A batch of text inputs with their targets and metadata."""

    text: List[MessageSeries]
    """Text content: one conversation (message series) per sample in the batch."""

    target: TargetType | None
    """Target data."""

    metadata: Dict[str, Any] | None
    """Additional metadata."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PredictionBatch(NamedTuple, Generic[TargetType]):
    """A batch of stored predictions with their targets, input text and metadata."""

    prediction: TargetType
    """Prediction data."""

    target: TargetType
    """Target data."""

    text: List[MessageSeries] | None
    """Conversation messages that were used as input."""

    metadata: Dict[str, Any] | None
    """Additional metadata."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ModelOutput(TypedDict):
|
|
44
|
+
"""The output batch produced by the model forward pass."""
|
|
45
|
+
|
|
46
|
+
generated_text: List[str]
|
|
47
|
+
"""The text generated by the model."""
|
|
48
|
+
|
|
49
|
+
input_ids: NotRequired[torch.Tensor | None]
|
|
50
|
+
"""The token ids of the input text."""
|
|
51
|
+
|
|
52
|
+
output_ids: NotRequired[torch.Tensor | None]
|
|
53
|
+
"""The token ids of the model output (usually containing both input and prediction)."""
|
|
54
|
+
|
|
55
|
+
attention_mask: NotRequired[torch.Tensor | None]
|
|
56
|
+
"""The attention mask for the input tokens."""
|
|
@@ -1,11 +1,19 @@
|
|
|
1
1
|
"""Language Model Wrappers API."""
|
|
2
2
|
|
|
3
|
-
from eva.language.models.wrappers.
|
|
4
|
-
from eva.language.models.wrappers.
|
|
3
|
+
from eva.language.models.wrappers.base import LanguageModel
|
|
4
|
+
from eva.language.models.wrappers.from_registry import ModelFromRegistry
|
|
5
|
+
from eva.language.models.wrappers.huggingface import HuggingFaceModel
|
|
6
|
+
from eva.language.models.wrappers.litellm import LiteLLMModel
|
|
5
7
|
|
|
6
8
|
try:
|
|
7
|
-
from eva.language.models.wrappers.vllm import
|
|
9
|
+
from eva.language.models.wrappers.vllm import VllmModel
|
|
8
10
|
|
|
9
|
-
__all__ = [
|
|
11
|
+
__all__ = [
|
|
12
|
+
"LanguageModel",
|
|
13
|
+
"HuggingFaceModel",
|
|
14
|
+
"LiteLLMModel",
|
|
15
|
+
"VllmModel",
|
|
16
|
+
"ModelFromRegistry",
|
|
17
|
+
]
|
|
10
18
|
except ImportError:
|
|
11
|
-
__all__ = ["
|
|
19
|
+
__all__ = ["LanguageModel", "HuggingFaceModel", "LiteLLMModel", "ModelFromRegistry"]
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Base class for language model wrappers."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.core.models.wrappers import base
|
|
9
|
+
from eva.language.data.messages import ModelSystemMessage
|
|
10
|
+
from eva.language.models.typings import ModelOutput, TextBatch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LanguageModel(base.BaseModel[TextBatch, ModelOutput]):
    """Abstract base for language model wrappers.

    Subclasses are expected to provide:
        - `load_model`: loads and instantiates the underlying model.
        - `model_forward`: runs the model; for API-backed models this can be
          an API call.
        - `format_inputs`: converts a `TextBatch` into the format that
          `model_forward` consumes.
    """

    def __init__(
        self, system_prompt: str | None, output_transforms: Callable | None = None
    ) -> None:
        """Creates a new model instance.

        Args:
            system_prompt: Optional system prompt. When empty or `None`, no
                system message is stored.
            output_transforms: Optional transforms applied to the output of
                the model's forward pass.
        """
        super().__init__(transforms=output_transforms)

        if system_prompt:
            self.system_message = ModelSystemMessage(content=system_prompt)
        else:
            self.system_message = None

    @override
    def forward(self, batch: TextBatch) -> ModelOutput:
        """Formats `batch` with `format_inputs` and delegates to the base forward pass."""
        return super().forward(self.format_inputs(batch))

    @abc.abstractmethod
    def format_inputs(self, batch: TextBatch) -> Any:
        """Converts the input batch into the format expected by the model."""
        raise NotImplementedError
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Helper wrapper for loading language models from the model registry."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict, List
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from eva.core.models.wrappers import base
|
|
9
|
+
from eva.core.utils import factory
|
|
10
|
+
from eva.language.models.networks.registry import model_registry
|
|
11
|
+
from eva.language.models.typings import TextBatch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelFromRegistry(base.BaseModel[TextBatch, List[str]]):
    """Wrapper class for language models from eva's model registry.

    This class can be used to load models available in eva's language model
    registry by name. New models can be registered by using the
    `@model_registry.register(model_name)` decorator.
    """

    # NOTE(review): the output type parameter is `List[str]`, whereas the
    # `LanguageModel` wrappers declare `ModelOutput`. Confirm which one the
    # registered models actually return.

    def __init__(
        self,
        model_name: str,
        model_kwargs: Dict[str, Any] | None = None,
        model_extra_kwargs: Dict[str, Any] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the model.

        Args:
            model_name: The registry name of the model to load.
            model_kwargs: The arguments used for instantiating the model.
            model_extra_kwargs: Extra arguments used for instantiating the
                model. They are merged over `model_kwargs`, so they win on
                key clashes.
            transforms: The transforms to apply to the output produced by
                the model.
        """
        super().__init__(transforms=transforms)

        self._model_name = model_name
        self._model_kwargs = model_kwargs or {}
        self._model_extra_kwargs = model_extra_kwargs or {}

        self.model = self.load_model()

    @override
    def load_model(self) -> nn.Module:
        """Instantiates the registered model with the merged keyword arguments."""
        # NOTE(review): this renames the *class* (shared by all instances), not
        # just this instance. Presumably this is so the wrapper is reported
        # under the registry name; if several are created, the last one wins.
        ModelFromRegistry.__name__ = self._model_name

        return factory.ModuleFactory(
            registry=model_registry,
            name=self._model_name,
            init_args=self._model_kwargs | self._model_extra_kwargs,
        )
|