kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. Click here for more details.

Files changed (131) hide show
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,92 @@
1
+ """Base classes for text datasets."""
2
+
3
+ import abc
4
+ from typing import Any, Dict, Generic
5
+
6
+ from typing_extensions import override
7
+
8
+ from eva.language.data.datasets.base import LanguageDataset
9
+ from eva.language.data.datasets.schemas import TransformsSchema
10
+ from eva.language.data.datasets.typings import TargetType, TextSample
11
+ from eva.language.data.messages import MessageSeries
12
+
13
+
14
class TextDataset(LanguageDataset[TextSample[TargetType]], abc.ABC, Generic[TargetType]):
    """Abstract dataset that yields `TextSample` items.

    Subclasses must implement `load_text` and `load_target`. `load_metadata`
    is an optional hook that returns `None` unless overridden.
    """

    def __init__(self, *args, transforms: TransformsSchema | None = None, **kwargs) -> None:
        """Initializes the dataset.

        Args:
            *args: Positional arguments forwarded to the base class.
            transforms: Optional schema holding the `text` and `target`
                transforms applied to each sample when it is loaded.
            **kwargs: Keyword arguments forwarded to the base class.
        """
        super().__init__(*args, **kwargs)

        self.transforms = transforms

    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        """Returns the metadata of a sample.

        The default implementation provides no metadata; `__getitem__`
        substitutes an empty dictionary in that case.

        Args:
            index: The index of the data sample.

        Returns:
            The sample metadata, or `None` when there is none.
        """
        return None

    @abc.abstractmethod
    def load_text(self, index: int) -> MessageSeries:
        """Returns the text content of a sample.

        Args:
            index: The index of the data sample.

        Returns:
            The text content as a series of messages.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load_target(self, index: int) -> TargetType:
        """Returns the target of a sample.

        Args:
            index: The index of the data sample.

        Returns:
            The target label.
        """
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> TextSample[TargetType]:
        text = self.load_text(index)
        target = self.load_target(index)
        # Fall back to an empty dict when no (or empty) metadata is provided.
        metadata = self.load_metadata(index) or {}
        return self._apply_transforms(TextSample(text=text, target=target, metadata=metadata))

    def _apply_transforms(self, sample: TextSample[TargetType]) -> TextSample[TargetType]:
        """Applies the configured transforms to the text and target.

        Args:
            sample: The text sample.

        Returns:
            The input sample unchanged when no transforms are configured,
            otherwise a new sample with the transformed text and target and
            the original metadata.
        """
        schema = self.transforms
        if not schema:
            return sample

        text, target = sample.text, sample.target
        if schema.text:
            text = schema.text(text)
        if schema.target:
            target = schema.target(target)
        return TextSample(text=text, target=target, metadata=sample.metadata)
@@ -0,0 +1,39 @@
1
+ """Typings for language datasets."""
2
+
3
+ from typing import Any, Generic, TypeVar
4
+
5
+ from typing_extensions import NamedTuple
6
+
7
+ from eva.language.data.messages import MessageSeries
8
+
9
+ TargetType = TypeVar("TargetType")
10
+ """The target data type."""
11
+
12
+
13
class TextSample(NamedTuple, Generic[TargetType]):
    """A single text sample: input messages, target and metadata."""

    text: MessageSeries
    """The conversation message(s) that make up the input."""

    target: TargetType | None
    """The target associated with the input, if any."""

    metadata: dict[str, Any] | None
    """Optional additional information about the sample."""
24
+
25
+
26
class PredictionSample(NamedTuple, Generic[TargetType]):
    """A single prediction paired with its target, input text and metadata."""

    prediction: TargetType
    """The predicted value."""

    target: TargetType
    """The target the prediction is compared against."""

    text: MessageSeries | None
    """The conversation messages that were used as input, if available."""

    metadata: dict[str, Any] | None
    """Optional additional information about the sample."""
@@ -0,0 +1,60 @@
1
+ """Types and classes for conversation messages in a multimodal context."""
2
+
3
+ import dataclasses
4
+ import enum
5
+ from typing import Any, Dict, List
6
+
7
+
8
class Role(str, enum.Enum):
    """The author role of a conversation message.

    Members are also `str` instances, so they compare equal to their plain
    string values.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
14
+
15
+
16
@dataclasses.dataclass
class Message:
    """A conversation message consisting of its content and author role."""

    content: str
    role: str

    def to_dict(self) -> Dict[str, Any]:
        """Returns the message fields as a dictionary keyed by field name."""
        as_dict: Dict[str, Any] = dataclasses.asdict(self)
        return as_dict
26
+
27
+
28
@dataclasses.dataclass
class UserMessage(Message):
    """A message authored by the user; `role` defaults to `Role.USER`."""

    role: str = Role.USER
33
+
34
+
35
@dataclasses.dataclass
class AssistantMessage(Message):
    """A message authored by the assistant; `role` defaults to `Role.ASSISTANT`."""

    role: str = Role.ASSISTANT
40
+
41
+
42
@dataclasses.dataclass
class SystemMessage(Message):
    """A system-level message; `role` defaults to `Role.SYSTEM`."""

    role: str = Role.SYSTEM
47
+
48
+
49
@dataclasses.dataclass
class ModelSystemMessage(SystemMessage):
    """System message carrying model-specific instructions."""
52
+
53
+
54
@dataclasses.dataclass
class TaskSystemMessage(SystemMessage):
    """System message carrying task-specific instructions."""
57
+
58
+
59
+ MessageSeries = List[Message]
60
+ """A series of conversation messages, can contain a mix of system, user, and AI messages."""
@@ -1,25 +1,29 @@
1
1
  """Language Models API."""
2
2
 
3
- from eva.language.models import modules, wrappers
4
- from eva.language.models.modules import TextModule
5
- from eva.language.models.wrappers import HuggingFaceTextModel, LiteLLMTextModel
3
+ from eva.language.models import modules, networks, wrappers
4
+ from eva.language.models.modules import LanguageModule, OfflineLanguageModule
5
+ from eva.language.models.wrappers import HuggingFaceModel, LiteLLMModel
6
6
 
7
7
  try:
8
- from eva.language.models.wrappers import VLLMTextModel
8
+ from eva.language.models.wrappers import VllmModel
9
9
 
10
10
  __all__ = [
11
11
  "modules",
12
12
  "wrappers",
13
- "TextModule",
14
- "HuggingFaceTextModel",
15
- "LiteLLMTextModel",
16
- "VLLMTextModel",
13
+ "networks",
14
+ "HuggingFaceModel",
15
+ "LiteLLMModel",
16
+ "VllmModel",
17
+ "LanguageModule",
18
+ "OfflineLanguageModule",
17
19
  ]
18
20
  except ImportError:
19
21
  __all__ = [
20
22
  "modules",
21
23
  "wrappers",
22
- "TextModule",
23
- "HuggingFaceTextModel",
24
- "LiteLLMTextModel",
24
+ "networks",
25
+ "HuggingFaceModel",
26
+ "LiteLLMModel",
27
+ "LanguageModule",
28
+ "OfflineLanguageModule",
25
29
  ]
@@ -1,5 +1,5 @@
1
1
  """Language Networks API."""
2
2
 
3
- from eva.language.models.modules.text import TextModule
3
+ from eva.language.models.modules.language import LanguageModule, OfflineLanguageModule
4
4
 
5
- __all__ = ["TextModule"]
5
+ __all__ = ["LanguageModule", "OfflineLanguageModule"]
@@ -0,0 +1,94 @@
1
+ """Model module for language models."""
2
+
3
+ from typing import Any
4
+
5
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
6
+ from torch import nn
7
+ from typing_extensions import override
8
+
9
+ from eva.core.metrics import structs as metrics_lib
10
+ from eva.core.models.modules import module
11
+ from eva.core.models.modules.utils import batch_postprocess
12
+ from eva.language.models.typings import ModelOutput, PredictionBatch, TextBatch
13
+
14
+
15
class LanguageModule(module.ModelModule):
    """Model module that runs a language model on text batches."""

    def __init__(
        self,
        model: nn.Module,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the language module.

        Args:
            model: Model instance used for the forward pass.
            metrics: Metrics schema for evaluation.
            postprocess: A helper function to post-process model outputs
                before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

        self.model = model

    @override
    def forward(self, batch: TextBatch, *args: Any, **kwargs: Any) -> ModelOutput:
        # Extra positional/keyword arguments are accepted but not forwarded.
        return self.model(batch)

    @override
    def validation_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: TextBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: TextBatch) -> STEP_OUTPUT:
        """Runs the model on a batch and assembles the step output.

        Args:
            batch: The batch holding the text, targets and metadata.

        Returns:
            A dictionary with the `inputs`, `predictions`, `targets` and
            `metadata` entries, merged with whatever else the model returned.
        """
        unpacked = TextBatch(*batch)
        output = self.forward(batch)
        # The generated text becomes the predictions; the remaining model
        # output entries are merged in afterwards.
        predictions = output.pop("generated_text")  # type: ignore

        step_output = {
            "inputs": unpacked.text,
            "predictions": predictions,
            "targets": unpacked.target,
            "metadata": unpacked.metadata,
        }
        return step_output | output
57
+
58
+
59
class OfflineLanguageModule(module.ModelModule):
    """Model module for batches that already contain the predictions.

    No model is invoked: the forward pass returns the batch unchanged.
    """

    def __init__(
        self,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the offline language module.

        Args:
            metrics: Metrics schema for evaluation.
            postprocess: A helper function to post-process model outputs
                before evaluation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

    @override
    def forward(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> PredictionBatch:
        return batch

    @override
    def validation_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: PredictionBatch, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    def _batch_step(self, batch: PredictionBatch) -> STEP_OUTPUT:
        """Maps the fields of a prediction batch onto the step output keys.

        Args:
            batch: The batch holding predictions, targets, text and metadata.

        Returns:
            A dictionary with the `inputs`, `predictions`, `targets` and
            `metadata` entries.
        """
        unpacked = PredictionBatch(*batch)
        return {
            "inputs": unpacked.text,
            "predictions": unpacked.prediction,
            "targets": unpacked.target,
            "metadata": unpacked.metadata,
        }
@@ -0,0 +1,12 @@
1
+ """Language networks API."""
2
+
3
+ from eva.language.models.networks.alibaba import Qwen205BInstruct
4
+ from eva.language.models.networks.api import Claude35Sonnet20240620, Claude37Sonnet20250219
5
+ from eva.language.models.networks.registry import model_registry
6
+
7
+ __all__ = [
8
+ "Claude35Sonnet20240620",
9
+ "Claude37Sonnet20250219",
10
+ "Qwen205BInstruct",
11
+ "model_registry",
12
+ ]
@@ -0,0 +1,26 @@
1
+ """Models from Alibaba."""
2
+
3
+ import torch
4
+
5
+ from eva.language.models import wrappers
6
+ from eva.language.models.networks.registry import model_registry
7
+
8
+
9
@model_registry.register("alibaba/qwen2-0-5b-instruct")
class Qwen205BInstruct(wrappers.HuggingFaceModel):
    """Qwen2 0.5B Instruct model, registered as `alibaba/qwen2-0-5b-instruct`."""

    def __init__(self, system_prompt: str | None = None, cache_dir: str | None = None):
        """Initializes the model.

        Args:
            system_prompt: Optional system prompt passed to the base wrapper.
            cache_dir: Optional cache directory passed via the model kwargs.
        """
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "cache_dir": cache_dir,
        }
        generation_kwargs = {"max_new_tokens": 512}

        super().__init__(
            model_name_or_path="Qwen/Qwen2-0.5B-Instruct",
            model_kwargs=model_kwargs,
            generation_kwargs=generation_kwargs,
            system_prompt=system_prompt,
            chat_mode=True,
        )
@@ -0,0 +1,11 @@
1
+ """Language API networks."""
2
+
3
+ from eva.language.models.networks.api.anthropic import (
4
+ Claude35Sonnet20240620,
5
+ Claude37Sonnet20250219,
6
+ )
7
+
8
+ __all__ = [
9
+ "Claude35Sonnet20240620",
10
+ "Claude37Sonnet20250219",
11
+ ]
@@ -0,0 +1,34 @@
1
+ """Models from Anthropic."""
2
+
3
+ import os
4
+
5
+ from eva.language.models import wrappers
6
+ from eva.language.models.networks.registry import model_registry
7
+
8
+
9
class _Claude(wrappers.LiteLLMModel):
    """Shared base for the Claude models; checks that an API key is set."""

    def __init__(self, model_name: str, system_prompt: str | None = None):
        """Initializes the model.

        Args:
            model_name: The model name passed to the base wrapper.
            system_prompt: Optional system prompt passed to the base wrapper.

        Raises:
            ValueError: If the `ANTHROPIC_API_KEY` environment variable is
                unset or empty.
        """
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")

        super().__init__(model_name=model_name, system_prompt=system_prompt)
17
+
18
+
19
@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
class Claude35Sonnet20240620(_Claude):
    """Claude 3.5 Sonnet (June 2024) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initializes the model.

        Args:
            system_prompt: Optional system prompt passed to the base class.
        """
        model_name = "claude-3-5-sonnet-20240620"
        super().__init__(model_name=model_name, system_prompt=system_prompt)
26
+
27
+
28
@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
class Claude37Sonnet20250219(_Claude):
    """Claude 3.7 Sonnet (February 2025) model."""

    def __init__(self, system_prompt: str | None = None):
        """Initializes the model.

        Args:
            system_prompt: Optional system prompt passed to the base class.
        """
        model_name = "claude-3-7-sonnet-20250219"
        super().__init__(model_name=model_name, system_prompt=system_prompt)
@@ -0,0 +1,5 @@
1
+ """Language Model Registry."""
2
+
3
+ from eva.core.utils.registry import Registry
4
+
5
+ model_registry = Registry()
@@ -0,0 +1,56 @@
1
+ """Type definitions for language models."""
2
+
3
+ from typing import Any, Dict, Generic, List, TypedDict, TypeVar
4
+
5
+ import torch
6
+ from typing_extensions import NamedTuple, NotRequired
7
+
8
+ from eva.language.data.messages import MessageSeries
9
+
10
+ TargetType = TypeVar("TargetType")
11
+ """The target data type."""
12
+
13
+
14
class TextBatch(NamedTuple, Generic[TargetType]):
    """A batch of text inputs with their targets and metadata."""

    text: List[MessageSeries]
    """One message series per item in the batch."""

    target: TargetType | None
    """The target data, if any."""

    metadata: Dict[str, Any] | None
    """Optional additional information about the batch."""
25
+
26
+
27
class PredictionBatch(NamedTuple, Generic[TargetType]):
    """A batch of predictions with their targets, input text and metadata."""

    prediction: TargetType
    """The prediction data."""

    target: TargetType
    """The target data the predictions are compared against."""

    text: List[MessageSeries] | None
    """The conversation messages that were used as input, if available."""

    metadata: Dict[str, Any] | None
    """Optional additional information about the batch."""
41
+
42
+
43
+ class ModelOutput(TypedDict):
44
+ """The output batch produced by the model forward pass."""
45
+
46
+ generated_text: List[str]
47
+ """The text generated by the model."""
48
+
49
+ input_ids: NotRequired[torch.Tensor | None]
50
+ """The token ids of the input text."""
51
+
52
+ output_ids: NotRequired[torch.Tensor | None]
53
+ """The token ids of the model output (usually containing both input and prediction)."""
54
+
55
+ attention_mask: NotRequired[torch.Tensor | None]
56
+ """The attention mask for the input tokens."""
@@ -1,11 +1,19 @@
1
1
  """Language Model Wrappers API."""
2
2
 
3
- from eva.language.models.wrappers.huggingface import HuggingFaceTextModel
4
- from eva.language.models.wrappers.litellm import LiteLLMTextModel
3
+ from eva.language.models.wrappers.base import LanguageModel
4
+ from eva.language.models.wrappers.from_registry import ModelFromRegistry
5
+ from eva.language.models.wrappers.huggingface import HuggingFaceModel
6
+ from eva.language.models.wrappers.litellm import LiteLLMModel
5
7
 
6
8
  try:
7
- from eva.language.models.wrappers.vllm import VLLMTextModel
9
+ from eva.language.models.wrappers.vllm import VllmModel
8
10
 
9
- __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel", "VLLMTextModel"]
11
+ __all__ = [
12
+ "LanguageModel",
13
+ "HuggingFaceModel",
14
+ "LiteLLMModel",
15
+ "VllmModel",
16
+ "ModelFromRegistry",
17
+ ]
10
18
  except ImportError:
11
- __all__ = ["HuggingFaceTextModel", "LiteLLMTextModel"]
19
+ __all__ = ["LanguageModel", "HuggingFaceModel", "LiteLLMModel", "ModelFromRegistry"]
@@ -0,0 +1,47 @@
1
+ """Base class for language model wrappers."""
2
+
3
+ import abc
4
+ from typing import Any, Callable
5
+
6
+ from typing_extensions import override
7
+
8
+ from eva.core.models.wrappers import base
9
+ from eva.language.data.messages import ModelSystemMessage
10
+ from eva.language.models.typings import ModelOutput, TextBatch
11
+
12
+
13
class LanguageModel(base.BaseModel[TextBatch, ModelOutput]):
    """Base class for language models.

    Classes that inherit from this should implement the following methods:
    - `load_model`: Loads & instantiates the model.
    - `model_forward`: Implements the forward pass of the model. For API models,
        this can be an API call.
    - `format_inputs`: Preprocesses and converts the input batch into the format
        expected by the `model_forward` method.
    """

    def __init__(
        self, system_prompt: str | None = None, output_transforms: Callable | None = None
    ) -> None:
        """Creates a new model instance.

        Args:
            system_prompt: The system prompt to use for the model (optional).
                When `None` or empty, no system message is created.
            output_transforms: Optional transforms to apply to the output of
                the model's forward pass.
        """
        super().__init__(transforms=output_transforms)

        # `system_prompt` is documented as optional, so it defaults to `None`;
        # an empty string is treated like a missing prompt.
        self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None

    @override
    def forward(self, batch: TextBatch) -> ModelOutput:
        """Formats the batch and delegates to the base class forward pass."""
        inputs = self.format_inputs(batch)
        return super().forward(inputs)

    @abc.abstractmethod
    def format_inputs(self, batch: TextBatch) -> Any:
        """Converts the inputs into the format expected by the model."""
        raise NotImplementedError
@@ -0,0 +1,54 @@
1
+ """Helper class to load language models from the model registry."""
2
+
3
+ from typing import Any, Callable, Dict, List
4
+
5
+ from torch import nn
6
+ from typing_extensions import override
7
+
8
+ from eva.core.models.wrappers import base
9
+ from eva.core.utils import factory
10
+ from eva.language.models.networks.registry import model_registry
11
+ from eva.language.models.typings import TextBatch
12
+
13
+
14
class ModelFromRegistry(base.BaseModel[TextBatch, List[str]]):
    """Wrapper class for language models available in the model registry.

    This class can be used to load models from eva's language model registry
    by name. New models can be registered by using the
    `@model_registry.register(model_name)` decorator.
    """

    def __init__(
        self,
        model_name: str,
        model_kwargs: Dict[str, Any] | None = None,
        model_extra_kwargs: Dict[str, Any] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initializes the model.

        Args:
            model_name: The name of the model to load.
            model_kwargs: The arguments used for instantiating the model.
            model_extra_kwargs: Extra arguments used for instantiating the
                model; on key clashes these override `model_kwargs`.
            transforms: The transforms to apply to the output produced by
                the model.
        """
        super().__init__(transforms=transforms)

        self._model_name = model_name
        self._model_kwargs = model_kwargs or {}
        self._model_extra_kwargs = model_extra_kwargs or {}

        self.model = self.load_model()

    @override
    def load_model(self) -> nn.Module:
        # NOTE: this renames the wrapper class itself, so the change is
        # shared by every instance of `ModelFromRegistry`.
        ModelFromRegistry.__name__ = self._model_name

        init_args = self._model_kwargs | self._model_extra_kwargs
        return factory.ModuleFactory(
            registry=model_registry,
            name=self._model_name,
            init_args=init_args,
        )
+ )