kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/config.py +15 -6
- eva/core/callbacks/writers/embeddings/base.py +44 -10
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/samplers/classification/balanced.py +24 -12
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/utils/wandb.py +4 -1
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +13 -1
- eva/core/utils/__init__.py +2 -1
- eva/core/utils/distributed.py +12 -0
- eva/core/utils/paths.py +14 -0
- eva/core/utils/requirements.py +52 -6
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +201 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +94 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +56 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +57 -11
- eva/language/models/wrappers/litellm.py +91 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +56 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +40 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +48 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +48 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +193 -0
- eva/multimodal/models/wrappers/litellm.py +58 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/batch/__init__.py +5 -0
- eva/multimodal/utils/batch/unpack.py +11 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/breakhis.py +5 -8
- eva/vision/data/datasets/classification/panda.py +12 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/datasets/segmentation/btcv.py +1 -1
- eva/vision/data/datasets/segmentation/consep.py +1 -1
- eva/vision/data/datasets/segmentation/lits17.py +1 -1
- eva/vision/data/datasets/segmentation/monusac.py +15 -6
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/base/__init__.py +2 -1
- eva/vision/data/transforms/base/monai.py +2 -2
- eva/vision/data/transforms/base/torchvision.py +33 -0
- eva/vision/data/transforms/common/squeeze.py +6 -3
- eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
- eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/flip.py +8 -7
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +63 -0
- eva/vision/data/transforms/spatial/rotate.py +8 -7
- eva/vision/data/transforms/spatial/spacing.py +7 -6
- eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
- eva/vision/models/networks/backbones/universal/vit.py +24 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0

eva/multimodal/data/dataloaders/collate_fn/text_image.py
@@ -0,0 +1,28 @@
+"""Collate functions for text-image data."""
+
+from typing import List
+
+from torch.utils.data._utils.collate import default_collate
+
+from eva.multimodal.data.datasets.typings import TextImageSample
+from eva.multimodal.models.typings import TextImageBatch
+
+
+def text_image_collate(batch: List[TextImageSample]) -> TextImageBatch:
+    """Collate function for text-image batches."""
+    texts, images, targets, metadata = zip(*batch, strict=False)
+
+    first_sample = batch[0]
+    metadata = None
+    if first_sample.metadata is not None:
+        metadata = {
+            k: [sample.metadata[k] for sample in batch if sample.metadata]
+            for k in first_sample.metadata.keys()
+        }
+
+    return TextImageBatch(
+        text=list(texts),
+        image=list(images),
+        target=default_collate(targets) if targets[0] is not None else None,
+        metadata=metadata,
+    )
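
A minimal usage sketch (not part of the diff): the new collate function plugs into a standard torch DataLoader whose dataset yields TextImageSample items, such as the multimodal PatchCamelyon dataset added further down; `dataset` here is a placeholder name.

# Sketch only: assumes `dataset` yields TextImageSample items.
from torch.utils.data import DataLoader

from eva.multimodal.data.dataloaders.collate_fn.text_image import text_image_collate

loader = DataLoader(dataset, batch_size=4, collate_fn=text_image_collate)
batch = next(iter(loader))  # TextImageBatch(text=[...], image=[...], target=..., metadata=...)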

eva/multimodal/data/datasets/base.py
@@ -0,0 +1,13 @@
+"""Multimodal Dataset base class."""
+
+import abc
+from typing import Generic, TypeVar
+
+from eva.core.data.datasets import base
+
+DataSample = TypeVar("DataSample")
+"""The data sample type."""
+
+
+class MultimodalDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+    """Base dataset class for multimodal tasks."""

eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py
@@ -0,0 +1,80 @@
+"""PatchCamelyon dataset with text prompts for multimodal classification."""
+
+from typing import Any, Dict, Literal
+
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.language.data.messages import MessageSeries, UserMessage
+from eva.multimodal.data.datasets.schemas import TransformsSchema
+from eva.multimodal.data.datasets.text_image import TextImageDataset
+from eva.vision.data import datasets as vision_datasets
+
+
+class PatchCamelyon(TextImageDataset[int], vision_datasets.PatchCamelyon):
+    """PatchCamelyon image classification using a multiple choice text prompt."""
+
+    _default_prompt = (
+        "You are a pathology expert helping pathologists to analyze images of tissue samples.\n"
+        "Question: Does this image show metastatic breast tissue?\n"
+        "Options: A: no, B: yes\n"
+        "Only answer with a single letter without further explanation. "
+        "Please always provide an answer, even if you are not sure.\n"
+        "Answer: "
+    )
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"],
+        download: bool = False,
+        transforms: TransformsSchema | None = None,
+        prompt: str | None = None,
+        max_samples: int | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: The path to the dataset root. This path should contain
+                the uncompressed h5 files and the metadata.
+            split: The dataset split for training, validation, or testing.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
+            prompt: The text prompt to use for classification (multple choice).
+            max_samples: Maximum number of samples to use. If None, use all samples.
+        """
+        super().__init__(root=root, split=split, download=download, transforms=transforms)
+
+        self.max_samples = max_samples
+        self.prompt = prompt or self._default_prompt
+
+        if self.max_samples is not None:
+            self._expected_length = {split: max_samples}
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"A": 0, "B": 1}
+
+    @override
+    def __len__(self) -> int:
+        return self.max_samples or self._fetch_dataset_length()
+
+    @override
+    def load_text(self, index: int) -> MessageSeries:
+        return [UserMessage(content=self.prompt)]
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        return vision_datasets.PatchCamelyon.load_data(self, index)
+
+    @override
+    def load_target(self, index: int) -> int:
+        return int(vision_datasets.PatchCamelyon.load_target(self, index).item())
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        return vision_datasets.PatchCamelyon.load_metadata(self, index)
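
For orientation, a hedged instantiation sketch (not part of the diff): the root path is a placeholder, and depending on the vision base class a prepare_data()/setup() call may still be required before indexing.

# Sketch only: "/data/patch_camelyon" is a placeholder path.
from eva.multimodal.data.datasets.multiple_choice.patch_camelyon import PatchCamelyon

dataset = PatchCamelyon(root="/data/patch_camelyon", split="val", max_samples=100)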

eva/multimodal/data/datasets/schemas.py
@@ -0,0 +1,14 @@
+"""Schema definitions for dataset classes."""
+
+import dataclasses
+from typing import Callable
+
+from eva.language.data.datasets import schemas as language_schemas
+
+
+@dataclasses.dataclass(frozen=True)
+class TransformsSchema(language_schemas.TransformsSchema):
+    """Schema for dataset transforms."""
+
+    image: Callable | None = None
+    """Image transformation"""

eva/multimodal/data/datasets/text_image.py
@@ -0,0 +1,77 @@
+"""Base classes for text-image datasets."""
+
+import abc
+from typing import Generic
+
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.language.data.datasets.text import TextDataset
+from eva.multimodal.data.datasets.base import MultimodalDataset
+from eva.multimodal.data.datasets.schemas import TransformsSchema
+from eva.multimodal.data.datasets.typings import TargetType, TextImageSample
+
+
+class TextImageDataset(
+    MultimodalDataset[TextImageSample[TargetType]], TextDataset, abc.ABC, Generic[TargetType]
+):
+    """Base dataset class for text-image tasks."""
+
+    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
+        """Initializes the dataset.
+
+        Args:
+            *args: Positional arguments for the base class.
+            transforms: The transforms to apply to the text, image and target when
+                loading the samples.
+            **kwargs: Keyword arguments for the base class.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.transforms = transforms
+
+    @abc.abstractmethod
+    def load_image(self, index: int) -> tv_tensors.Image:
+        """Returns the image content.
+
+        Args:
+            index: The index of the data sample.
+
+        Returns:
+            The image content.
+        """
+        raise NotImplementedError
+
+    @override
+    def __getitem__(self, index: int) -> TextImageSample[TargetType]:
+        item = TextImageSample(
+            text=self.load_text(index),
+            image=self.load_image(index),
+            target=self.load_target(index),
+            metadata=self.load_metadata(index) or {},
+        )
+        return self._apply_transforms(item)
+
+    @override
+    def _apply_transforms(self, sample: TextImageSample[TargetType]) -> TextImageSample[TargetType]:
+        """Applies the dataset transforms to the text, image and target.
+
+        Args:
+            sample: The sample containing text, image, target and metadata.
+
+        Returns:
+            The transformed sample.
+        """
+        if self.transforms:
+            text = self.transforms.text(sample.text) if self.transforms.text else sample.text
+            image = self.transforms.image(sample.image) if self.transforms.image else sample.image
+            target = (
+                self.transforms.target(sample.target) if self.transforms.target else sample.target
+            )
+            return TextImageSample(
+                text=text,
+                image=image,
+                target=target,
+                metadata=sample.metadata,
+            )
+        return sample

eva/multimodal/data/datasets/typings.py
@@ -0,0 +1,27 @@
+"""Typings for multimodal datasets."""
+
+from typing import Any, Generic, TypeVar
+
+from torchvision import tv_tensors
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextImageSample(NamedTuple, Generic[TargetType]):
+    """Text and image sample with target and metadata."""
+
+    text: MessageSeries
+    """One or multiple conversation messages."""
+
+    image: tv_tensors.Image
+    """Image tensor."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: dict[str, Any] | None
+    """Additional metadata."""
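
A quick construction sketch of the new sample type (illustrative values only; the image is a dummy tensor):

# Sketch only: a dummy 96x96 RGB image and a single user message.
import torch
from torchvision import tv_tensors

from eva.language.data.messages import UserMessage
from eva.multimodal.data.datasets.typings import TextImageSample

sample = TextImageSample(
    text=[UserMessage(content="Does this image show metastatic breast tissue?")],
    image=tv_tensors.Image(torch.zeros(3, 96, 96)),
    target=1,
    metadata=None,
)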

eva/multimodal/models/modules/vision_language.py
@@ -0,0 +1,56 @@
+"""Model module for vision-language models."""
+
+from typing import Any
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import nn
+from typing_extensions import override
+
+from eva.core.metrics import structs as metrics_lib
+from eva.core.models.modules import module
+from eva.core.models.modules.utils import batch_postprocess
+from eva.language.models.typings import ModelOutput
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class VisionLanguageModule(module.ModelModule):
+    """Model module for vision-language tasks."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        metrics: metrics_lib.MetricsSchema | None = None,
+        postprocess: batch_postprocess.BatchPostProcess | None = None,
+    ) -> None:
+        """Initializes the text inference module.
+
+        Args:
+            model: Model instance to use for forward pass.
+            metrics: Metrics schema for evaluation.
+            postprocess: A helper function to post-process model outputs before evaluation.
+        """
+        super().__init__(metrics=metrics, postprocess=postprocess)
+
+        self.model = model
+
+    @override
+    def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> ModelOutput:
+        return self.model(batch)
+
+    @override
+    def validation_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    @override
+    def test_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        return self._batch_step(batch)
+
+    def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
+        text, _, targets, metadata = TextImageBatch(*batch)
+        output = self.forward(batch)
+        return {
+            "inputs": text,
+            "predictions": output.pop("generated_text"),  # type: ignore
+            "targets": targets,
+            "metadata": metadata,
+        } | output
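
A rough wiring sketch (assumptions: the registry wrapper defined later in this diff and a valid ANTHROPIC_API_KEY in the environment; metrics and postprocess are omitted and would normally come from the task config):

# Sketch only: pairs the module with a registry-backed multimodal wrapper.
from eva.multimodal.models import wrappers
from eva.multimodal.models.modules.vision_language import VisionLanguageModule

module = VisionLanguageModule(
    model=wrappers.ModelFromRegistry(model_name="anthropic/claude-3-5-sonnet-20240620"),
)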

eva/multimodal/models/networks/__init__.py
@@ -0,0 +1,14 @@
+"""Multimodal networks API."""
+
+from eva.multimodal.models.networks.alibaba import Qwen25VL7BInstruct
+from eva.multimodal.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
+from eva.multimodal.models.networks.others import PathoR13b
+from eva.multimodal.models.networks.registry import model_registry
+
+__all__ = [
+    "Claude35Sonnet20240620",
+    "Claude37Sonnet20250219",
+    "PathoR13b",
+    "Qwen25VL7BInstruct",
+    "model_registry",
+]

eva/multimodal/models/networks/alibaba.py
@@ -0,0 +1,40 @@
+"""Models from Alibaba."""
+
+import torch
+
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+@model_registry.register("alibaba/qwen2-5-vl-7b-instruct")
+class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
+    """Qwen2.5-VL 7B Instruct model."""
+
+    def __init__(
+        self,
+        system_prompt: str | None = None,
+        cache_dir: str | None = None,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        """Initialize the model."""
+        super().__init__(
+            model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
+            model_class="Qwen2_5_VLForConditionalGeneration",
+            model_kwargs={
+                "torch_dtype": torch.bfloat16,
+                "trust_remote_code": True,
+                "cache_dir": cache_dir,
+                "attn_implementation": attn_implementation,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+                "do_sample": False,
+            },
+            processor_kwargs={
+                "padding": True,
+                "padding_side": "left",
+                "max_pixels": 451584,  # 672*672
+            },
+            system_prompt=system_prompt,
+            image_key="images",
+        )

eva/multimodal/models/networks/api/anthropic.py
@@ -0,0 +1,34 @@
+"""Models from Anthropic."""
+
+import os
+
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+class _Claude(wrappers.LiteLLMModel):
+    """Base class for Claude models."""
+
+    def __init__(self, model_name: str, system_prompt: str | None = None):
+        if not os.getenv("ANTHROPIC_API_KEY"):
+            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+        super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+class Claude35Sonnet20240620(_Claude):
+    """Claude 3.5 Sonnet (June 2024) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+class Claude37Sonnet20250219(_Claude):
+    """Claude 3.7 Sonnet (February 2025) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
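
The registered classes can also be instantiated directly; this sketch assumes ANTHROPIC_API_KEY is already exported (otherwise the wrapper raises a ValueError, as shown above).

# Sketch only: direct instantiation of a registered Claude wrapper.
from eva.multimodal.models.networks.api.anthropic import Claude35Sonnet20240620

model = Claude35Sonnet20240620(system_prompt="You are a pathology expert.")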

eva/multimodal/models/networks/others.py
@@ -0,0 +1,48 @@
+"""Models from other providers (non-major entities)."""
+
+import os
+
+import torch
+
+from eva.core.utils import requirements
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+@model_registry.register("others/wenchuanzhang_patho-r1-3b")
+class PathoR13b(wrappers.HuggingFaceModel):
+    """Patho-R1-3B model by Wenchuan Zhang."""
+
+    def __init__(
+        self,
+        system_prompt: str | None = None,
+        cache_dir: str | None = None,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        """Initialize the Patho-R1-3B model."""
+        requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+        if not os.getenv("HF_TOKEN"):
+            raise ValueError("HF_TOKEN env variable must be set.")
+
+        super().__init__(
+            model_name_or_path="WenchuanZhang/Patho-R1-3B",
+            model_class="Qwen2_5_VLForConditionalGeneration",
+            model_kwargs={
+                "torch_dtype": torch.float16,
+                "trust_remote_code": True,
+                "cache_dir": cache_dir,
+                "attn_implementation": attn_implementation,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+                "do_sample": False,
+            },
+            processor_kwargs={
+                "padding": True,
+                "padding_side": "left",
+                "max_pixels": 451584,  # 672*672
+            },
+            system_prompt=system_prompt,
+            image_key="images",
+        )

eva/multimodal/models/typings.py
@@ -0,0 +1,27 @@
+"""Type definitions for multimodal models."""
+
+from typing import Any, Dict, Generic, List, TypeVar
+
+from torchvision import tv_tensors
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextImageBatch(NamedTuple, Generic[TargetType]):
+    """Text and image sample with target and metadata."""
+
+    text: List[MessageSeries]
+    """A batch of conversations with one or multiple messages each."""
+
+    image: List[tv_tensors.Image]
+    """Image tensor."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: Dict[str, Any] | None
+    """Additional metadata."""

eva/multimodal/models/wrappers/__init__.py
@@ -0,0 +1,13 @@
+"""Multimodal Wrapper API."""
+
+from eva.multimodal.models.wrappers.base import VisionLanguageModel
+from eva.multimodal.models.wrappers.from_registry import ModelFromRegistry
+from eva.multimodal.models.wrappers.huggingface import HuggingFaceModel
+from eva.multimodal.models.wrappers.litellm import LiteLLMModel
+
+__all__ = [
+    "HuggingFaceModel",
+    "LiteLLMModel",
+    "ModelFromRegistry",
+    "VisionLanguageModel",
+]

eva/multimodal/models/wrappers/base.py
@@ -0,0 +1,48 @@
+"""Base class for vision language model wrappers."""
+
+import abc
+from typing import Any, Callable
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.language.data.messages import ModelSystemMessage
+from eva.language.models.typings import ModelOutput
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class VisionLanguageModel(base.BaseModel[TextImageBatch, ModelOutput]):
+    """Base class for multimodal models.
+
+    Classes that inherit from this should implement the following methods:
+    - `load_model`: Loads & instantiates the model.
+    - `model_forward`: Implements the forward pass of the model. For API models,
+        this can be an API call.
+    - `format_inputs`: Preprocesses and converts the input batch into the format
+        expected by the `model_forward` method.
+    """
+
+    def __init__(
+        self, system_prompt: str | None, output_transforms: Callable | None = None
+    ) -> None:
+        """Creates a new model instance.
+
+        Args:
+            system_prompt: The system prompt to use for the model (optional).
+            output_transforms: Optional transforms to apply to the output of
+                the model's forward pass.
+        """
+        super().__init__(transforms=output_transforms)
+
+        self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+    @override
+    def forward(self, batch: TextImageBatch) -> ModelOutput:
+        """Forward pass of the model."""
+        inputs = self.format_inputs(batch)
+        return super().forward(inputs)
+
+    @abc.abstractmethod
+    def format_inputs(self, batch: TextImageBatch) -> Any:
+        """Converts the inputs into the format expected by the model."""
+        raise NotImplementedError

eva/multimodal/models/wrappers/from_registry.py
@@ -0,0 +1,54 @@
+"""Vision backbone helper class."""
+
+from typing import Any, Callable, Dict, List
+
+from torch import nn
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.core.utils import factory
+from eva.multimodal.models.networks.registry import model_registry
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class ModelFromRegistry(base.BaseModel[TextImageBatch, List[str]]):
+    """Wrapper class for vision backbone models.
+
+    This class can be used by load backbones available in eva's
+    model registry by name. New backbones can be registered by using
+    the `@backbone_registry.register(model_name)` decorator.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        model_extra_kwargs: Dict[str, Any] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name: The name of the model to load.
+            model_kwargs: The arguments used for instantiating the model.
+            model_extra_kwargs: Extra arguments used for instantiating the model.
+            transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(transforms=transforms)
+
+        self._model_name = model_name
+        self._model_kwargs = model_kwargs or {}
+        self._model_extra_kwargs = model_extra_kwargs or {}
+
+        self.model = self.load_model()
+
+    @override
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
+            registry=model_registry,
+            name=self._model_name,
+            init_args=self._model_kwargs | self._model_extra_kwargs,
+        )
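
Closing usage sketch for the registry wrapper (not part of the diff): the kwargs shown are ones the registered Qwen class accepts in alibaba.py above, and the cache path is a placeholder.

# Sketch only: resolves "alibaba/qwen2-5-vl-7b-instruct" from the registry and
# forwards the kwargs to Qwen25VL7BInstruct.__init__.
from eva.multimodal.models.wrappers import ModelFromRegistry

model = ModelFromRegistry(
    model_name="alibaba/qwen2-5-vl-7b-instruct",
    model_kwargs={"cache_dir": "/tmp/hf-cache", "attn_implementation": "flash_attention_2"},
)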