kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (131)
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
eva/multimodal/data/dataloaders/__init__.py
@@ -0,0 +1,5 @@
+ """Multimodal dataloaders API."""
+
+ from eva.multimodal.data.dataloaders.collate_fn import text_image_collate
+
+ __all__ = ["text_image_collate"]
eva/multimodal/data/dataloaders/collate_fn/__init__.py
@@ -0,0 +1,5 @@
+ """Multimodal collate functions API."""
+
+ from eva.multimodal.data.dataloaders.collate_fn.text_image import text_image_collate
+
+ __all__ = ["text_image_collate"]
eva/multimodal/data/dataloaders/collate_fn/text_image.py
@@ -0,0 +1,28 @@
+ """Collate functions for text-image data."""
+
+ from typing import List
+
+ from torch.utils.data._utils.collate import default_collate
+
+ from eva.multimodal.data.datasets.typings import TextImageSample
+ from eva.multimodal.models.typings import TextImageBatch
+
+
+ def text_image_collate(batch: List[TextImageSample]) -> TextImageBatch:
+     """Collate function for text-image batches."""
+     texts, images, targets, metadata = zip(*batch, strict=False)
+
+     first_sample = batch[0]
+     metadata = None
+     if first_sample.metadata is not None:
+         metadata = {
+             k: [sample.metadata[k] for sample in batch if sample.metadata]
+             for k in first_sample.metadata.keys()
+         }
+
+     return TextImageBatch(
+         text=list(texts),
+         image=list(images),
+         target=default_collate(targets) if targets[0] is not None else None,
+         metadata=metadata,
+     )
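
For reference, a minimal usage sketch of the new collate function (not part of the diff; assumes eva 0.4.1 plus torch/torchvision are installed, and the field values are illustrative only):

import torch
from torchvision import tv_tensors

from eva.language.data.messages import UserMessage
from eva.multimodal.data.dataloaders.collate_fn import text_image_collate
from eva.multimodal.data.datasets.typings import TextImageSample

# Two illustrative samples; a DataLoader would normally produce this list.
samples = [
    TextImageSample(
        text=[UserMessage(content="Does this image show metastatic tissue?")],
        image=tv_tensors.Image(torch.zeros(3, 96, 96)),
        target=i % 2,
        metadata={"index": i},
    )
    for i in range(2)
]

batch = text_image_collate(samples)
# batch.text and batch.image stay Python lists, batch.target is collated to a
# tensor, and batch.metadata becomes a dict of per-key lists, e.g. {"index": [0, 1]}.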
eva/multimodal/data/datasets/__init__.py
@@ -0,0 +1,6 @@
+ """Multimodal datasets API."""
+
+ from eva.multimodal.data.datasets.multiple_choice.patch_camelyon import PatchCamelyon
+ from eva.multimodal.data.datasets.text_image import TextImageDataset
+
+ __all__ = ["TextImageDataset", "PatchCamelyon"]
eva/multimodal/data/datasets/base.py
@@ -0,0 +1,13 @@
+ """Multimodal Dataset base class."""
+
+ import abc
+ from typing import Generic, TypeVar
+
+ from eva.core.data.datasets import base
+
+ DataSample = TypeVar("DataSample")
+ """The data sample type."""
+
+
+ class MultimodalDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+     """Base dataset class for multimodal tasks."""
eva/multimodal/data/datasets/multiple_choice/__init__.py
@@ -0,0 +1,5 @@
+ """Multiple choice datasets."""
+
+ from eva.multimodal.data.datasets.multiple_choice.patch_camelyon import PatchCamelyon
+
+ __all__ = ["PatchCamelyon"]
eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py
@@ -0,0 +1,80 @@
+ """PatchCamelyon dataset with text prompts for multimodal classification."""
+
+ from typing import Any, Dict, Literal
+
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+ from eva.language.data.messages import MessageSeries, UserMessage
+ from eva.multimodal.data.datasets.schemas import TransformsSchema
+ from eva.multimodal.data.datasets.text_image import TextImageDataset
+ from eva.vision.data import datasets as vision_datasets
+
+
+ class PatchCamelyon(TextImageDataset[int], vision_datasets.PatchCamelyon):
+     """PatchCamelyon image classification using a multiple choice text prompt."""
+
+     _default_prompt = (
+         "You are a pathology expert helping pathologists to analyze images of tissue samples.\n"
+         "Question: Does this image show metastatic breast tissue?\n"
+         "Options: A: no, B: yes\n"
+         "Only answer with a single letter without further explanation. "
+         "Please always provide an answer, even if you are not sure.\n"
+         "Answer: "
+     )
+
+     def __init__(
+         self,
+         root: str,
+         split: Literal["train", "val", "test"],
+         download: bool = False,
+         transforms: TransformsSchema | None = None,
+         prompt: str | None = None,
+         max_samples: int | None = None,
+     ) -> None:
+         """Initializes the dataset.
+
+         Args:
+             root: The path to the dataset root. This path should contain
+                 the uncompressed h5 files and the metadata.
+             split: The dataset split for training, validation, or testing.
+             download: Whether to download the data for the specified split.
+                 Note that the download will be executed only by additionally
+                 calling the :meth:`prepare_data` method.
+             transforms: A function/transform which returns a transformed
+                 version of the raw data samples.
+             prompt: The text prompt to use for classification (multple choice).
+             max_samples: Maximum number of samples to use. If None, use all samples.
+         """
+         super().__init__(root=root, split=split, download=download, transforms=transforms)
+
+         self.max_samples = max_samples
+         self.prompt = prompt or self._default_prompt
+
+         if self.max_samples is not None:
+             self._expected_length = {split: max_samples}
+
+     @property
+     @override
+     def class_to_idx(self) -> Dict[str, int]:
+         return {"A": 0, "B": 1}
+
+     @override
+     def __len__(self) -> int:
+         return self.max_samples or self._fetch_dataset_length()
+
+     @override
+     def load_text(self, index: int) -> MessageSeries:
+         return [UserMessage(content=self.prompt)]
+
+     @override
+     def load_image(self, index: int) -> tv_tensors.Image:
+         return vision_datasets.PatchCamelyon.load_data(self, index)
+
+     @override
+     def load_target(self, index: int) -> int:
+         return int(vision_datasets.PatchCamelyon.load_target(self, index).item())
+
+     @override
+     def load_metadata(self, index: int) -> Dict[str, Any] | None:
+         return vision_datasets.PatchCamelyon.load_metadata(self, index)
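
A hedged instantiation sketch for this new multimodal PatchCamelyon dataset (not part of the diff; the root path is a placeholder and the uncompressed h5 files are assumed to be already available there):

from eva.multimodal.data.datasets import PatchCamelyon

dataset = PatchCamelyon(root="./data/patch_camelyon", split="val", max_samples=8)
sample = dataset[0]   # TextImageSample with the prompt, the image tensor and the target
print(sample.target)  # 0 ("A": no) or 1 ("B": yes), per class_to_idx above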
eva/multimodal/data/datasets/schemas.py
@@ -0,0 +1,14 @@
+ """Schema definitions for dataset classes."""
+
+ import dataclasses
+ from typing import Callable
+
+ from eva.language.data.datasets import schemas as language_schemas
+
+
+ @dataclasses.dataclass(frozen=True)
+ class TransformsSchema(language_schemas.TransformsSchema):
+     """Schema for dataset transforms."""
+
+     image: Callable | None = None
+     """Image transformation"""
eva/multimodal/data/datasets/text_image.py
@@ -0,0 +1,77 @@
+ """Base classes for text-image datasets."""
+
+ import abc
+ from typing import Generic
+
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+ from eva.language.data.datasets.text import TextDataset
+ from eva.multimodal.data.datasets.base import MultimodalDataset
+ from eva.multimodal.data.datasets.schemas import TransformsSchema
+ from eva.multimodal.data.datasets.typings import TargetType, TextImageSample
+
+
+ class TextImageDataset(
+     MultimodalDataset[TextImageSample[TargetType]], TextDataset, abc.ABC, Generic[TargetType]
+ ):
+     """Base dataset class for text-image tasks."""
+
+     def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
+         """Initializes the dataset.
+
+         Args:
+             *args: Positional arguments for the base class.
+             transforms: The transforms to apply to the text, image and target when
+                 loading the samples.
+             **kwargs: Keyword arguments for the base class.
+         """
+         super().__init__(*args, **kwargs)
+
+         self.transforms = transforms
+
+     @abc.abstractmethod
+     def load_image(self, index: int) -> tv_tensors.Image:
+         """Returns the image content.
+
+         Args:
+             index: The index of the data sample.
+
+         Returns:
+             The image content.
+         """
+         raise NotImplementedError
+
+     @override
+     def __getitem__(self, index: int) -> TextImageSample[TargetType]:
+         item = TextImageSample(
+             text=self.load_text(index),
+             image=self.load_image(index),
+             target=self.load_target(index),
+             metadata=self.load_metadata(index) or {},
+         )
+         return self._apply_transforms(item)
+
+     @override
+     def _apply_transforms(self, sample: TextImageSample[TargetType]) -> TextImageSample[TargetType]:
+         """Applies the dataset transforms to the text, image and target.
+
+         Args:
+             sample: The sample containing text, image, target and metadata.
+
+         Returns:
+             The transformed sample.
+         """
+         if self.transforms:
+             text = self.transforms.text(sample.text) if self.transforms.text else sample.text
+             image = self.transforms.image(sample.image) if self.transforms.image else sample.image
+             target = (
+                 self.transforms.target(sample.target) if self.transforms.target else sample.target
+             )
+             return TextImageSample(
+                 text=text,
+                 image=image,
+                 target=target,
+                 metadata=sample.metadata,
+             )
+         return sample
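
To illustrate the contract, here is a toy subclass (not in the package) that fills in the loader hooks used by `__getitem__`; it assumes the inherited text dataset base needs no extra constructor arguments:

from typing import Any, Dict

import torch
from torchvision import tv_tensors

from eva.language.data.messages import MessageSeries, UserMessage
from eva.multimodal.data.datasets.text_image import TextImageDataset


class ToyTextImageDataset(TextImageDataset[int]):
    """Two-sample toy dataset, purely for illustration."""

    def __len__(self) -> int:
        return 2

    def load_text(self, index: int) -> MessageSeries:
        return [UserMessage(content="Describe the tissue in this image.")]

    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.zeros(3, 32, 32))

    def load_target(self, index: int) -> int:
        return index % 2

    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        return {"index": index}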
eva/multimodal/data/datasets/typings.py
@@ -0,0 +1,27 @@
+ """Typings for multimodal datasets."""
+
+ from typing import Any, Generic, TypeVar
+
+ from torchvision import tv_tensors
+ from typing_extensions import NamedTuple
+
+ from eva.language.data.messages import MessageSeries
+
+ TargetType = TypeVar("TargetType")
+ """The target data type."""
+
+
+ class TextImageSample(NamedTuple, Generic[TargetType]):
+     """Text and image sample with target and metadata."""
+
+     text: MessageSeries
+     """One or multiple conversation messages."""
+
+     image: tv_tensors.Image
+     """Image tensor."""
+
+     target: TargetType | None
+     """Target data."""
+
+     metadata: dict[str, Any] | None
+     """Additional metadata."""
eva/multimodal/models/__init__.py
@@ -0,0 +1,8 @@
+ """Multimodal models API."""
+
+ from eva.multimodal.models import networks, wrappers
+
+ __all__ = [
+     "networks",
+     "wrappers",
+ ]
eva/multimodal/models/modules/__init__.py
@@ -0,0 +1,5 @@
+ """Multimodal Networks API."""
+
+ from eva.multimodal.models.modules.vision_language import VisionLanguageModule
+
+ __all__ = ["VisionLanguageModule"]
eva/multimodal/models/modules/vision_language.py
@@ -0,0 +1,56 @@
+ """Model module for vision-language models."""
+
+ from typing import Any
+
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.core.metrics import structs as metrics_lib
+ from eva.core.models.modules import module
+ from eva.core.models.modules.utils import batch_postprocess
+ from eva.language.models.typings import ModelOutput
+ from eva.multimodal.models.typings import TextImageBatch
+
+
+ class VisionLanguageModule(module.ModelModule):
+     """Model module for vision-language tasks."""
+
+     def __init__(
+         self,
+         model: nn.Module,
+         metrics: metrics_lib.MetricsSchema | None = None,
+         postprocess: batch_postprocess.BatchPostProcess | None = None,
+     ) -> None:
+         """Initializes the text inference module.
+
+         Args:
+             model: Model instance to use for forward pass.
+             metrics: Metrics schema for evaluation.
+             postprocess: A helper function to post-process model outputs before evaluation.
+         """
+         super().__init__(metrics=metrics, postprocess=postprocess)
+
+         self.model = model
+
+     @override
+     def forward(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> ModelOutput:
+         return self.model(batch)
+
+     @override
+     def validation_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     @override
+     def test_step(self, batch: TextImageBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         return self._batch_step(batch)
+
+     def _batch_step(self, batch: TextImageBatch) -> STEP_OUTPUT:
+         text, _, targets, metadata = TextImageBatch(*batch)
+         output = self.forward(batch)
+         return {
+             "inputs": text,
+             "predictions": output.pop("generated_text"),  # type: ignore
+             "targets": targets,
+             "metadata": metadata,
+         } | output
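
A hedged wiring sketch for evaluation-only use of this module with a registry-loaded model (not part of the diff; metrics and postprocess are omitted, and the Claude wrapper requires ANTHROPIC_API_KEY to be set):

from eva.multimodal.models.modules import VisionLanguageModule
from eva.multimodal.models.wrappers import ModelFromRegistry

# Load a model by its registry name; the Claude wrapper raises if ANTHROPIC_API_KEY is unset.
model = ModelFromRegistry(model_name="anthropic/claude-3-5-sonnet-20240620")
module = VisionLanguageModule(model=model)
# validation_step/test_step unpack a TextImageBatch and report the model's
# "generated_text" output as the predictions.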
eva/multimodal/models/networks/__init__.py
@@ -0,0 +1,14 @@
+ """Multimodal networks API."""
+
+ from eva.multimodal.models.networks.alibaba import Qwen25VL7BInstruct
+ from eva.multimodal.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
+ from eva.multimodal.models.networks.others import PathoR13b
+ from eva.multimodal.models.networks.registry import model_registry
+
+ __all__ = [
+     "Claude35Sonnet20240620",
+     "Claude37Sonnet20250219",
+     "PathoR13b",
+     "Qwen25VL7BInstruct",
+     "model_registry",
+ ]
eva/multimodal/models/networks/alibaba.py
@@ -0,0 +1,40 @@
+ """Models from Alibaba."""
+
+ import torch
+
+ from eva.multimodal.models import wrappers
+ from eva.multimodal.models.networks.registry import model_registry
+
+
+ @model_registry.register("alibaba/qwen2-5-vl-7b-instruct")
+ class Qwen25VL7BInstruct(wrappers.HuggingFaceModel):
+     """Qwen2.5-VL 7B Instruct model."""
+
+     def __init__(
+         self,
+         system_prompt: str | None = None,
+         cache_dir: str | None = None,
+         attn_implementation: str = "flash_attention_2",
+     ):
+         """Initialize the model."""
+         super().__init__(
+             model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
+             model_class="Qwen2_5_VLForConditionalGeneration",
+             model_kwargs={
+                 "torch_dtype": torch.bfloat16,
+                 "trust_remote_code": True,
+                 "cache_dir": cache_dir,
+                 "attn_implementation": attn_implementation,
+             },
+             generation_kwargs={
+                 "max_new_tokens": 512,
+                 "do_sample": False,
+             },
+             processor_kwargs={
+                 "padding": True,
+                 "padding_side": "left",
+                 "max_pixels": 451584,  # 672*672
+             },
+             system_prompt=system_prompt,
+             image_key="images",
+         )
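
Direct instantiation of this wrapper might look like the following sketch (downloads the checkpoint from Hugging Face; the cache path and system prompt are placeholders, and switching attn_implementation away from the default is only needed when flash-attention is unavailable):

from eva.multimodal.models.networks import Qwen25VL7BInstruct

model = Qwen25VL7BInstruct(
    system_prompt="You are a pathology expert.",
    cache_dir="/tmp/hf-cache",
    attn_implementation="sdpa",
)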
eva/multimodal/models/networks/api/__init__.py
@@ -0,0 +1,11 @@
+ """Multimodal API networks."""
+
+ from eva.multimodal.models.networks.api.anthropic import (
+     Claude35Sonnet20240620,
+     Claude37Sonnet20250219,
+ )
+
+ __all__ = [
+     "Claude35Sonnet20240620",
+     "Claude37Sonnet20250219",
+ ]
eva/multimodal/models/networks/api/anthropic.py
@@ -0,0 +1,34 @@
+ """Models from Anthropic."""
+
+ import os
+
+ from eva.multimodal.models import wrappers
+ from eva.multimodal.models.networks.registry import model_registry
+
+
+ class _Claude(wrappers.LiteLLMModel):
+     """Base class for Claude models."""
+
+     def __init__(self, model_name: str, system_prompt: str | None = None):
+         if not os.getenv("ANTHROPIC_API_KEY"):
+             raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+         super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+ @model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+ class Claude35Sonnet20240620(_Claude):
+     """Claude 3.5 Sonnet (June 2024) model."""
+
+     def __init__(self, system_prompt: str | None = None):
+         """Initialize the model."""
+         super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+ @model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+ class Claude37Sonnet20250219(_Claude):
+     """Claude 3.7 Sonnet (February 2025) model."""
+
+     def __init__(self, system_prompt: str | None = None):
+         """Initialize the model."""
+         super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
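
Mirroring the pattern above, an additional API-backed model could presumably be registered like this (a sketch, not in the package; the registry name and LiteLLM model id are hypothetical and assume the multimodal LiteLLMModel wrapper accepts them):

from eva.multimodal.models import wrappers
from eva.multimodal.models.networks.registry import model_registry


@model_registry.register("openai/gpt-4o")  # hypothetical registry name
class GPT4o(wrappers.LiteLLMModel):
    """Example registration of an additional API model."""

    def __init__(self, system_prompt: str | None = None):
        super().__init__(model_name="gpt-4o", system_prompt=system_prompt)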
eva/multimodal/models/networks/others.py
@@ -0,0 +1,48 @@
+ """Models from other providers (non-major entities)."""
+
+ import os
+
+ import torch
+
+ from eva.core.utils import requirements
+ from eva.multimodal.models import wrappers
+ from eva.multimodal.models.networks.registry import model_registry
+
+
+ @model_registry.register("others/wenchuanzhang_patho-r1-3b")
+ class PathoR13b(wrappers.HuggingFaceModel):
+     """Patho-R1-3B model by Wenchuan Zhang."""
+
+     def __init__(
+         self,
+         system_prompt: str | None = None,
+         cache_dir: str | None = None,
+         attn_implementation: str = "flash_attention_2",
+     ):
+         """Initialize the Patho-R1-3B model."""
+         requirements.check_min_versions(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+         if not os.getenv("HF_TOKEN"):
+             raise ValueError("HF_TOKEN env variable must be set.")
+
+         super().__init__(
+             model_name_or_path="WenchuanZhang/Patho-R1-3B",
+             model_class="Qwen2_5_VLForConditionalGeneration",
+             model_kwargs={
+                 "torch_dtype": torch.float16,
+                 "trust_remote_code": True,
+                 "cache_dir": cache_dir,
+                 "attn_implementation": attn_implementation,
+             },
+             generation_kwargs={
+                 "max_new_tokens": 512,
+                 "do_sample": False,
+             },
+             processor_kwargs={
+                 "padding": True,
+                 "padding_side": "left",
+                 "max_pixels": 451584,  # 672*672
+             },
+             system_prompt=system_prompt,
+             image_key="images",
+         )
eva/multimodal/models/networks/registry.py
@@ -0,0 +1,5 @@
+ """Multimodal Model Registry."""
+
+ from eva.core.utils.registry import Registry
+
+ model_registry = Registry()
eva/multimodal/models/typings.py
@@ -0,0 +1,27 @@
+ """Type definitions for multimodal models."""
+
+ from typing import Any, Dict, Generic, List, TypeVar
+
+ from torchvision import tv_tensors
+ from typing_extensions import NamedTuple
+
+ from eva.language.data.messages import MessageSeries
+
+ TargetType = TypeVar("TargetType")
+ """The target data type."""
+
+
+ class TextImageBatch(NamedTuple, Generic[TargetType]):
+     """Text and image sample with target and metadata."""
+
+     text: List[MessageSeries]
+     """A batch of conversations with one or multiple messages each."""
+
+     image: List[tv_tensors.Image]
+     """Image tensor."""
+
+     target: TargetType | None
+     """Target data."""
+
+     metadata: Dict[str, Any] | None
+     """Additional metadata."""
eva/multimodal/models/wrappers/__init__.py
@@ -0,0 +1,13 @@
+ """Multimodal Wrapper API."""
+
+ from eva.multimodal.models.wrappers.base import VisionLanguageModel
+ from eva.multimodal.models.wrappers.from_registry import ModelFromRegistry
+ from eva.multimodal.models.wrappers.huggingface import HuggingFaceModel
+ from eva.multimodal.models.wrappers.litellm import LiteLLMModel
+
+ __all__ = [
+     "HuggingFaceModel",
+     "LiteLLMModel",
+     "ModelFromRegistry",
+     "VisionLanguageModel",
+ ]
eva/multimodal/models/wrappers/base.py
@@ -0,0 +1,48 @@
+ """Base class for vision language model wrappers."""
+
+ import abc
+ from typing import Any, Callable
+
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import base
+ from eva.language.data.messages import ModelSystemMessage
+ from eva.language.models.typings import ModelOutput
+ from eva.multimodal.models.typings import TextImageBatch
+
+
+ class VisionLanguageModel(base.BaseModel[TextImageBatch, ModelOutput]):
+     """Base class for multimodal models.
+
+     Classes that inherit from this should implement the following methods:
+     - `load_model`: Loads & instantiates the model.
+     - `model_forward`: Implements the forward pass of the model. For API models,
+         this can be an API call.
+     - `format_inputs`: Preprocesses and converts the input batch into the format
+         expected by the `model_forward` method.
+     """
+
+     def __init__(
+         self, system_prompt: str | None, output_transforms: Callable | None = None
+     ) -> None:
+         """Creates a new model instance.
+
+         Args:
+             system_prompt: The system prompt to use for the model (optional).
+             output_transforms: Optional transforms to apply to the output of
+                 the model's forward pass.
+         """
+         super().__init__(transforms=output_transforms)
+
+         self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+     @override
+     def forward(self, batch: TextImageBatch) -> ModelOutput:
+         """Forward pass of the model."""
+         inputs = self.format_inputs(batch)
+         return super().forward(inputs)
+
+     @abc.abstractmethod
+     def format_inputs(self, batch: TextImageBatch) -> Any:
+         """Converts the inputs into the format expected by the model."""
+         raise NotImplementedError
eva/multimodal/models/wrappers/from_registry.py
@@ -0,0 +1,54 @@
+ """Vision backbone helper class."""
+
+ from typing import Any, Callable, Dict, List
+
+ from torch import nn
+ from typing_extensions import override
+
+ from eva.core.models.wrappers import base
+ from eva.core.utils import factory
+ from eva.multimodal.models.networks.registry import model_registry
+ from eva.multimodal.models.typings import TextImageBatch
+
+
+ class ModelFromRegistry(base.BaseModel[TextImageBatch, List[str]]):
+     """Wrapper class for vision backbone models.
+
+     This class can be used by load backbones available in eva's
+     model registry by name. New backbones can be registered by using
+     the `@backbone_registry.register(model_name)` decorator.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         model_kwargs: Dict[str, Any] | None = None,
+         model_extra_kwargs: Dict[str, Any] | None = None,
+         transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the model.
+
+         Args:
+             model_name: The name of the model to load.
+             model_kwargs: The arguments used for instantiating the model.
+             model_extra_kwargs: Extra arguments used for instantiating the model.
+             transforms: The transforms to apply to the output tensor
+                 produced by the model.
+         """
+         super().__init__(transforms=transforms)
+
+         self._model_name = model_name
+         self._model_kwargs = model_kwargs or {}
+         self._model_extra_kwargs = model_extra_kwargs or {}
+
+         self.model = self.load_model()
+
+     @override
+     def load_model(self) -> nn.Module:
+         ModelFromRegistry.__name__ = self._model_name
+
+         return factory.ModuleFactory(
+             registry=model_registry,
+             name=self._model_name,
+             init_args=self._model_kwargs | self._model_extra_kwargs,
+         )