kaiko-eva 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/config.py +4 -0
- eva/core/cli/setup.py +1 -1
- eva/core/data/dataloaders/__init__.py +1 -2
- eva/core/data/dataloaders/dataloader.py +3 -1
- eva/core/data/samplers/random.py +17 -10
- eva/core/interface/interface.py +21 -0
- eva/core/loggers/log/__init__.py +2 -1
- eva/core/loggers/log/table.py +73 -0
- eva/core/models/modules/module.py +2 -2
- eva/core/models/wrappers/base.py +2 -2
- eva/core/models/wrappers/from_function.py +3 -3
- eva/core/models/wrappers/from_torchhub.py +9 -7
- eva/core/models/wrappers/huggingface.py +4 -5
- eva/core/models/wrappers/onnx.py +5 -5
- eva/core/trainers/trainer.py +2 -0
- eva/language/__init__.py +2 -1
- eva/language/callbacks/__init__.py +5 -0
- eva/language/callbacks/writers/__init__.py +5 -0
- eva/language/callbacks/writers/prediction.py +176 -0
- eva/language/data/dataloaders/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/language/data/dataloaders/collate_fn/text.py +57 -0
- eva/language/data/datasets/__init__.py +3 -1
- eva/language/data/datasets/{language.py → base.py} +1 -1
- eva/language/data/datasets/classification/base.py +3 -43
- eva/language/data/datasets/classification/pubmedqa.py +36 -4
- eva/language/data/datasets/prediction.py +151 -0
- eva/language/data/datasets/schemas.py +18 -0
- eva/language/data/datasets/text.py +92 -0
- eva/language/data/datasets/typings.py +39 -0
- eva/language/data/messages.py +60 -0
- eva/language/models/__init__.py +15 -11
- eva/language/models/modules/__init__.py +2 -2
- eva/language/models/modules/language.py +93 -0
- eva/language/models/networks/__init__.py +12 -0
- eva/language/models/networks/alibaba.py +26 -0
- eva/language/models/networks/api/__init__.py +11 -0
- eva/language/models/networks/api/anthropic.py +34 -0
- eva/language/models/networks/registry.py +5 -0
- eva/language/models/typings.py +39 -0
- eva/language/models/wrappers/__init__.py +13 -5
- eva/language/models/wrappers/base.py +47 -0
- eva/language/models/wrappers/from_registry.py +54 -0
- eva/language/models/wrappers/huggingface.py +44 -8
- eva/language/models/wrappers/litellm.py +81 -46
- eva/language/models/wrappers/vllm.py +37 -13
- eva/language/utils/__init__.py +2 -1
- eva/language/utils/str_to_int_tensor.py +20 -12
- eva/language/utils/text/__init__.py +5 -0
- eva/language/utils/text/messages.py +113 -0
- eva/multimodal/__init__.py +6 -0
- eva/multimodal/callbacks/__init__.py +5 -0
- eva/multimodal/callbacks/writers/__init__.py +5 -0
- eva/multimodal/callbacks/writers/prediction.py +39 -0
- eva/multimodal/data/__init__.py +5 -0
- eva/multimodal/data/dataloaders/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
- eva/multimodal/data/datasets/__init__.py +6 -0
- eva/multimodal/data/datasets/base.py +13 -0
- eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
- eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
- eva/multimodal/data/datasets/schemas.py +14 -0
- eva/multimodal/data/datasets/text_image.py +77 -0
- eva/multimodal/data/datasets/typings.py +27 -0
- eva/multimodal/models/__init__.py +8 -0
- eva/multimodal/models/modules/__init__.py +5 -0
- eva/multimodal/models/modules/vision_language.py +55 -0
- eva/multimodal/models/networks/__init__.py +14 -0
- eva/multimodal/models/networks/alibaba.py +39 -0
- eva/multimodal/models/networks/api/__init__.py +11 -0
- eva/multimodal/models/networks/api/anthropic.py +34 -0
- eva/multimodal/models/networks/others.py +47 -0
- eva/multimodal/models/networks/registry.py +5 -0
- eva/multimodal/models/typings.py +27 -0
- eva/multimodal/models/wrappers/__init__.py +13 -0
- eva/multimodal/models/wrappers/base.py +47 -0
- eva/multimodal/models/wrappers/from_registry.py +54 -0
- eva/multimodal/models/wrappers/huggingface.py +180 -0
- eva/multimodal/models/wrappers/litellm.py +56 -0
- eva/multimodal/utils/__init__.py +1 -0
- eva/multimodal/utils/image/__init__.py +5 -0
- eva/multimodal/utils/image/encode.py +28 -0
- eva/multimodal/utils/text/__init__.py +1 -0
- eva/multimodal/utils/text/messages.py +79 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
- eva/vision/data/transforms/__init__.py +2 -1
- eva/vision/data/transforms/spatial/__init__.py +2 -1
- eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
- eva/vision/data/transforms/spatial/functional/resize.py +26 -0
- eva/vision/data/transforms/spatial/resize.py +62 -0
- eva/vision/models/wrappers/from_registry.py +6 -5
- eva/vision/models/wrappers/from_timm.py +6 -4
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/METADATA +10 -2
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/RECORD +98 -40
- eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
- eva/core/data/dataloaders/collate_fn/collate.py +0 -24
- eva/language/models/modules/text.py +0 -85
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.2.dist-info → kaiko_eva-0.4.0.dist-info}/licenses/LICENSE +0 -0
eva/multimodal/models/networks/api/anthropic.py
@@ -0,0 +1,34 @@
+"""Models from Anthropic."""
+
+import os
+
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+class _Claude(wrappers.LiteLLMModel):
+    """Base class for Claude models."""
+
+    def __init__(self, model_name: str, system_prompt: str | None = None):
+        if not os.getenv("ANTHROPIC_API_KEY"):
+            raise ValueError("ANTHROPIC_API_KEY env variable must be set.")
+
+        super().__init__(model_name=model_name, system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-5-sonnet-20240620")
+class Claude35Sonnet20240620(_Claude):
+    """Claude 3.5 Sonnet (June 2024) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-5-sonnet-20240620", system_prompt=system_prompt)
+
+
+@model_registry.register("anthropic/claude-3-7-sonnet-20250219")
+class Claude37Sonnet20250219(_Claude):
+    """Claude 3.7 Sonnet (February 2025) model."""
+
+    def __init__(self, system_prompt: str | None = None):
+        """Initialize the model."""
+        super().__init__(model_name="claude-3-7-sonnet-20250219", system_prompt=system_prompt)
eva/multimodal/models/networks/others.py
@@ -0,0 +1,47 @@
+"""Models from other providers (non-major entities)."""
+
+import os
+
+import torch
+
+from eva.core.utils import requirements
+from eva.multimodal.models import wrappers
+from eva.multimodal.models.networks.registry import model_registry
+
+
+@model_registry.register("others/wenchuanzhang_patho-r1-3b")
+class PathoR13b(wrappers.HuggingFaceModel):
+    """Patho-R1-3B model by Wenchuan Zhang."""
+
+    def __init__(
+        self,
+        system_prompt: str | None = None,
+        cache_dir: str | None = None,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        """Initialize the Patho-R1-3B model."""
+        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+        if not os.getenv("HF_TOKEN"):
+            raise ValueError("HF_TOKEN env variable must be set.")
+
+        super().__init__(
+            model_name_or_path="WenchuanZhang/Patho-R1-3B",
+            model_class="Qwen2_5_VLForConditionalGeneration",
+            model_kwargs={
+                "torch_dtype": torch.float16,
+                "trust_remote_code": True,
+                "cache_dir": cache_dir,
+                "attn_implementation": attn_implementation,
+            },
+            generation_kwargs={
+                "max_new_tokens": 512,
+                "do_sample": False,
+            },
+            processor_kwargs={
+                "padding": True,
+                "padding_side": "left",
+                "max_pixels": 451584,  # 672*672
+            },
+            system_prompt=system_prompt,
+        )
eva/multimodal/models/typings.py
@@ -0,0 +1,27 @@
+"""Type definitions for multimodal models."""
+
+from typing import Any, Dict, Generic, List, TypeVar
+
+from torchvision import tv_tensors
+from typing_extensions import NamedTuple
+
+from eva.language.data.messages import MessageSeries
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+class TextImageBatch(NamedTuple, Generic[TargetType]):
+    """Text and image sample with target and metadata."""
+
+    text: List[MessageSeries]
+    """A batch of conversations with one or multiple messages each."""
+
+    image: List[tv_tensors.Image]
+    """Image tensor."""
+
+    target: TargetType | None
+    """Target data."""
+
+    metadata: Dict[str, Any] | None
+    """Additional metadata."""
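For orientation, here is a minimal sketch of assembling such a batch by hand. The random image tensor, the empty conversation list, and the printed shape are placeholders for illustration; the concrete MessageSeries content is defined in eva/language/data/messages.py, which is not shown in this diff.

import torch
from torchvision import tv_tensors

from eva.multimodal.models.typings import TextImageBatch

# Placeholder conversation (MessageSeries) and a random RGB image, for illustration only.
conversations = [[]]
images = [tv_tensors.Image(torch.rand(3, 224, 224))]

batch = TextImageBatch(text=conversations, image=images, target=None, metadata=None)
print(batch.image[0].shape)  # torch.Size([3, 224, 224])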
eva/multimodal/models/wrappers/__init__.py
@@ -0,0 +1,13 @@
+"""Multimodal Wrapper API."""
+
+from eva.multimodal.models.wrappers.base import VisionLanguageModel
+from eva.multimodal.models.wrappers.from_registry import ModelFromRegistry
+from eva.multimodal.models.wrappers.huggingface import HuggingFaceModel
+from eva.multimodal.models.wrappers.litellm import LiteLLMModel
+
+__all__ = [
+    "HuggingFaceModel",
+    "LiteLLMModel",
+    "ModelFromRegistry",
+    "VisionLanguageModel",
+]
eva/multimodal/models/wrappers/base.py
@@ -0,0 +1,47 @@
+"""Base class for vision language model wrappers."""
+
+import abc
+from typing import Any, Callable, List
+
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.language.data.messages import ModelSystemMessage
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class VisionLanguageModel(base.BaseModel[TextImageBatch, List[str]]):
+    """Base class for multimodal models.
+
+    Classes that inherit from this should implement the following methods:
+    - `load_model`: Loads & instantiates the model.
+    - `model_forward`: Implements the forward pass of the model. For API models,
+        this can be an API call.
+    - `format_inputs`: Preprocesses and converts the input batch into the format
+        expected by the `model_forward` method.
+    """
+
+    def __init__(
+        self, system_prompt: str | None, output_transforms: Callable | None = None
+    ) -> None:
+        """Creates a new model instance.
+
+        Args:
+            system_prompt: The system prompt to use for the model (optional).
+            output_transforms: Optional transforms to apply to the output of
+                the model's forward pass.
+        """
+        super().__init__(transforms=output_transforms)
+
+        self.system_message = ModelSystemMessage(content=system_prompt) if system_prompt else None
+
+    @override
+    def forward(self, batch: TextImageBatch) -> List[str]:
+        """Forward pass of the model."""
+        inputs = self.format_inputs(batch)
+        return super().forward(inputs)
+
+    @abc.abstractmethod
+    def format_inputs(self, batch: TextImageBatch) -> Any:
+        """Converts the inputs into the format expected by the model."""
+        raise NotImplementedError
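A minimal, hypothetical subclass to illustrate the contract described in the docstring above. EchoModel is not part of eva; it assumes the BaseModel parent dispatches forward through model_forward and applies the optional output transforms afterwards, as the concrete wrappers further down in this diff do.

from typing import Any, List

from torch import nn
from typing_extensions import override

from eva.multimodal.models import wrappers
from eva.multimodal.models.typings import TextImageBatch


class EchoModel(wrappers.VisionLanguageModel):
    """Toy model returning a fixed reply per sample (illustrative only)."""

    def __init__(self, system_prompt: str | None = None) -> None:
        super().__init__(system_prompt=system_prompt)
        self.model = self.load_model()

    @override
    def load_model(self) -> nn.Module:
        return nn.Identity()  # no real weights are needed for this sketch

    @override
    def format_inputs(self, batch: TextImageBatch) -> Any:
        return batch.text  # hand the conversations to model_forward unchanged

    @override
    def model_forward(self, batch: Any) -> List[str]:
        return ["ok"] * len(batch)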
eva/multimodal/models/wrappers/from_registry.py
@@ -0,0 +1,54 @@
+"""Vision backbone helper class."""
+
+from typing import Any, Callable, Dict, List
+
+from torch import nn
+from typing_extensions import override
+
+from eva.core.models.wrappers import base
+from eva.core.utils import factory
+from eva.multimodal.models.networks.registry import model_registry
+from eva.multimodal.models.typings import TextImageBatch
+
+
+class ModelFromRegistry(base.BaseModel[TextImageBatch, List[str]]):
+    """Wrapper class for vision backbone models.
+
+    This class can be used by load backbones available in eva's
+    model registry by name. New backbones can be registered by using
+    the `@backbone_registry.register(model_name)` decorator.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        model_extra_kwargs: Dict[str, Any] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the model.
+
+        Args:
+            model_name: The name of the model to load.
+            model_kwargs: The arguments used for instantiating the model.
+            model_extra_kwargs: Extra arguments used for instantiating the model.
+            transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(transforms=transforms)
+
+        self._model_name = model_name
+        self._model_kwargs = model_kwargs or {}
+        self._model_extra_kwargs = model_extra_kwargs or {}
+
+        self.model = self.load_model()
+
+    @override
+    def load_model(self) -> nn.Module:
+        ModelFromRegistry.__name__ = self._model_name
+
+        return factory.ModuleFactory(
+            registry=model_registry,
+            name=self._model_name,
+            init_args=self._model_kwargs | self._model_extra_kwargs,
+        )
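A usage sketch: pulling one of the registered networks by its registry name. "anthropic/claude-3-5-sonnet-20240620" is registered earlier in this diff and requires the ANTHROPIC_API_KEY environment variable; the system prompt string here is purely illustrative.

from eva.multimodal.models.wrappers import ModelFromRegistry

model = ModelFromRegistry(
    model_name="anthropic/claude-3-5-sonnet-20240620",
    # init_args are forwarded to the registered class, whose __init__ accepts system_prompt.
    model_kwargs={"system_prompt": "You are a concise pathology assistant."},
)
# predictions = model(batch)  # batch: TextImageBatch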
eva/multimodal/models/wrappers/huggingface.py
@@ -0,0 +1,180 @@
+"""HuggingFace Vision-Language Model Wrapper."""
+
+import functools
+from typing import Any, Callable, Dict, List
+
+import torch
+import transformers
+from loguru import logger
+from torch import nn
+from typing_extensions import override
+
+from eva.language.models.typings import TextBatch
+from eva.language.utils.text import messages as language_message_utils
+from eva.multimodal.models.typings import TextImageBatch
+from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.text import messages as message_utils
+
+
+class HuggingFaceModel(base.VisionLanguageModel):
+    """Lightweight wrapper for Huggingface VLMs.
+
+    Args:
+        model_name_or_path: The name of the model to use.
+        model_class: The class of the model to use.
+        model_kwargs: Additional model arguments.
+        processor_kwargs: Additional processor arguments.
+        generation_kwargs: Additional generation arguments.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_class: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
+        processor_kwargs: Dict[str, Any] | None = None,
+        generation_kwargs: Dict[str, Any] | None = None,
+    ):
+        """Initialize the HuggingFace model wrapper.
+
+        Args:
+            model_name_or_path: The name or path of the model to use.
+            model_class: The class of the model to use.
+            model_kwargs: Additional model arguments.
+            system_prompt: System prompt to use.
+            processor_kwargs: Additional processor arguments.
+            generation_kwargs: Additional generation arguments.
+        """
+        super().__init__(system_prompt=system_prompt)
+
+        self.model_name_or_path = model_name_or_path
+        self.model_kwargs = model_kwargs or {}
+        self.base_model_class = model_class
+        self.processor_kwargs = processor_kwargs or {}
+        self.generation_kwargs = generation_kwargs or {}
+
+        self.processor = self.load_processor()
+        self.model = self.load_model()
+
+    @override
+    def format_inputs(self, batch: TextImageBatch | TextBatch) -> Dict[str, torch.Tensor]:
+        """Formats inputs for HuggingFace models.
+
+        Args:
+            batch: A batch of text and image inputs.
+
+        Returns:
+            A dictionary produced by the provided processor following a format like:
+            {
+                "input_ids": ...,
+                "attention_mask": ...,
+                "pixel_values": ...
+            }
+        """
+        message_batch, image_batch, _, _ = self._unpack_batch(batch)
+        with_images = image_batch is not None
+
+        message_batch = language_message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+        if self.processor.chat_template is not None:  # type: ignore
+            templated_text = [
+                self.processor.apply_chat_template(  # type: ignore
+                    message,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for message in map(
+                    functools.partial(
+                        message_utils.format_huggingface_message,
+                        with_images=with_images,
+                    ),
+                    message_batch,
+                )
+            ]
+        else:
+            raise NotImplementedError("Currently only chat models are supported.")
+
+        processor_inputs = {
+            "text": templated_text,
+            "return_tensors": "pt",
+            **self.processor_kwargs,
+        }
+
+        if with_images:
+            processor_inputs["image"] = [[image] for image in image_batch]
+
+        return self.processor(**processor_inputs).to(self.model.device)  # type: ignore
+
+    @override
+    def model_forward(self, batch: Dict[str, torch.Tensor]) -> List[str]:
+        """Generates text output from the model. Is called by the `generate` method.
+
+        Args:
+            batch: A dictionary containing the input data, which may include:
+                - "text": List of messages formatted for the model.
+                - "image": List of image tensors.
+
+        Returns:
+            A dictionary containing the processed input and the model's output.
+        """
+        output = self.model.generate(**batch, **self.generation_kwargs)  # type: ignore
+        return self._decode_output(output, batch["input_ids"].shape[-1])
+
+    @override
+    def load_model(self) -> nn.Module:
+        """Setting up the model. Used for delayed model initialization.
+
+        Raises:
+            ValueError: If the model class is not found in transformers or if the model
+                does not support gradient checkpointing but it is enabled.
+        """
+        logger.info(f"Configuring model: {self.model_name_or_path}")
+        if hasattr(transformers, self.base_model_class):
+            model_class = getattr(transformers, self.base_model_class)
+        else:
+            raise ValueError(f"Model class {self.base_model_class} not found in transformers")
+
+        model = model_class.from_pretrained(self.model_name_or_path, **self.model_kwargs)
+
+        if not hasattr(model, "generate"):
+            raise ValueError(f"Model {self.model_name_or_path} does not support generation. ")
+
+        return model
+
+    def load_processor(self) -> Callable:
+        """Initialize the processor."""
+        return transformers.AutoProcessor.from_pretrained(
+            self.model_name_or_path,
+            **self.processor_kwargs,
+        )
+
+    def _unpack_batch(self, batch: TextImageBatch | TextBatch) -> tuple:
+        if isinstance(batch, TextImageBatch):
+            return batch.text, batch.image, batch.target, batch.metadata
+        return batch.text, None, batch.target, batch.metadata
+
+    def _decode_output(self, output: torch.Tensor, instruction_length: int) -> List[str]:
+        """Decode the model's batch output to text.
+
+        Args:
+            output: The raw output from the model.
+            instruction_length: The length of the instruction in the input.
+
+        Returns:
+            A list of decoded text responses.
+        """
+        decoded_input = self.processor.batch_decode(  # type: ignore
+            output[:, :instruction_length], skip_special_tokens=True
+        )
+        decoded_output = self.processor.batch_decode(  # type: ignore
+            output[:, instruction_length:], skip_special_tokens=True
+        )
+
+        logger.debug(f"Decoded input: {decoded_input}")
+        logger.debug(f"Decoded output: {decoded_output}")
+
+        return decoded_output
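Instantiating this wrapper directly looks roughly as follows. The checkpoint, model class, and kwargs mirror what the PathoR13b network earlier in this diff passes in; the system prompt and generation settings are illustrative rather than recommended values.

import torch

from eva.multimodal.models.wrappers import HuggingFaceModel

model = HuggingFaceModel(
    model_name_or_path="WenchuanZhang/Patho-R1-3B",
    model_class="Qwen2_5_VLForConditionalGeneration",
    model_kwargs={"torch_dtype": torch.float16, "trust_remote_code": True},
    processor_kwargs={"padding": True, "padding_side": "left"},
    generation_kwargs={"max_new_tokens": 128, "do_sample": False},
    system_prompt="Answer with a single word.",
)
# predictions = model(batch)  # batch: TextImageBatch or TextBatch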
eva/multimodal/models/wrappers/litellm.py
@@ -0,0 +1,56 @@
+"""LiteLLM vision-language model wrapper."""
+
+import logging
+from typing import Any, Dict, List
+
+from typing_extensions import override
+
+from eva.language.models import wrappers as language_wrappers
+from eva.language.utils.text import messages as language_message_utils
+from eva.multimodal.models.typings import TextImageBatch
+from eva.multimodal.models.wrappers import base
+from eva.multimodal.utils.text import messages as message_utils
+
+
+class LiteLLMModel(base.VisionLanguageModel):
+    """Wrapper class for LiteLLM vision-language models."""
+
+    def __init__(
+        self,
+        model_name: str,
+        model_kwargs: Dict[str, Any] | None = None,
+        system_prompt: str | None = None,
+        log_level: int | None = logging.INFO,
+    ):
+        """Initialize the LiteLLM Wrapper.
+
+        Args:
+            model_name: The name of the model to use.
+            model_kwargs: Additional keyword arguments to pass during
+                generation (e.g., `temperature`, `max_tokens`).
+            system_prompt: The system prompt to use (optional).
+            log_level: Optional logging level for LiteLLM. Defaults to WARNING.
+        """
+        super().__init__(system_prompt=system_prompt)
+
+        self.language_model = language_wrappers.LiteLLMModel(
+            model_name=model_name,
+            model_kwargs=model_kwargs,
+            system_prompt=system_prompt,
+            log_level=log_level,
+        )
+
+    @override
+    def format_inputs(self, batch: TextImageBatch) -> List[List[Dict[str, Any]]]:
+        message_batch, image_batch, _, _ = TextImageBatch(*batch)
+
+        message_batch = language_message_utils.batch_insert_system_message(
+            message_batch, self.system_message
+        )
+        message_batch = list(map(language_message_utils.combine_system_messages, message_batch))
+
+        return list(map(message_utils.format_litellm_message, message_batch, image_batch))
+
+    @override
+    def model_forward(self, batch: List[List[Dict[str, Any]]]) -> List[str]:
+        return self.language_model.model_forward(batch)
eva/multimodal/utils/__init__.py
@@ -0,0 +1 @@
+"""Multimodal utilities API."""
eva/multimodal/utils/image/encode.py
@@ -0,0 +1,28 @@
+"""Image encoding utilities."""
+
+import base64
+import io
+from typing import Literal
+
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional as F
+
+
+def encode_image(image: tv_tensors.Image, encoding: Literal["base64"]) -> str:
+    """Encodes an image tensor into a string format.
+
+    Args:
+        image: The image tensor to encode.
+        encoding: The encoding format to use. Currently only supports "base64".
+
+    Returns:
+        An encoded string representation of the image.
+    """
+    match encoding:
+        case "base64":
+            image_bytes = io.BytesIO()
+            F.to_pil_image(image).save(image_bytes, format="PNG", optimize=True)
+            image_bytes.seek(0)
+            return base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+        case _:
+            raise ValueError(f"Unsupported encoding type: {encoding}. Supported: 'base64'")
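A quick usage sketch: encoding a random uint8 image and checking that the decoded payload is a PNG. The import goes through eva.multimodal.utils.image, as the message utilities later in this diff do; base64 is the only encoding the function supports.

import base64

import torch
from torchvision import tv_tensors

from eva.multimodal.utils import image as image_utils

image = tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
encoded = image_utils.encode_image(image, encoding="base64")

print(base64.b64decode(encoded)[:8])  # b'\x89PNG\r\n\x1a\n' -- the PNG signature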
eva/multimodal/utils/text/__init__.py
@@ -0,0 +1 @@
+"""Multimodal text utilities API."""
eva/multimodal/utils/text/messages.py
@@ -0,0 +1,79 @@
+"""Message formatting utilities for multimodal models."""
+
+from typing import Any, Dict, List
+
+from torchvision import tv_tensors
+
+from eva.language import utils as language_utils
+from eva.language.data.messages import MessageSeries, Role
+from eva.multimodal.utils import image as image_utils
+
+
+def format_huggingface_message(
+    message: MessageSeries, with_images: bool = False
+) -> List[Dict[str, Any]]:
+    """Formats a message series into a format suitable for Huggingface models."""
+    if not with_images:
+        return language_utils.format_chat_message(message)
+
+    formatted_message = []
+    for item in message:
+        if item.role == Role.SYSTEM:
+            formatted_message += language_utils.format_chat_message([item])
+        else:
+            formatted_message.append(
+                {
+                    "role": item.role,
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": str(item.content),
+                        },
+                        {"type": "image"},
+                    ],
+                }
+            )
+    return formatted_message
+
+
+def format_litellm_message(
+    message: MessageSeries, image: tv_tensors.Image | None
+) -> List[Dict[str, Any]]:
+    """Format a message series for LiteLLM API.
+
+    Args:
+        message: The message series to format.
+        image: Optional image to include in the message.
+
+    Returns:
+        A list of formatted message dictionaries.
+    """
+    if image is None:
+        return language_utils.format_chat_message(message)
+
+    formatted_message = []
+    for item in message:
+        if item.role == Role.SYSTEM:
+            formatted_message += language_utils.format_chat_message([item])
+        else:
+            formatted_message.append(
+                {
+                    "role": item.role,
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": str(item.content),
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": (
+                                    f"data:image/png;base64,"
+                                    f"{image_utils.encode_image(image, encoding='base64')}"
+                                )
+                            },
+                        },
+                    ],
+                }
+            )
+    return formatted_message
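For reference, the payload format_litellm_message builds for a single user turn with an image has roughly this shape. The role value comes from the message's Role, and the question text and the truncated base64 string are illustrative placeholders.

[
    {
        "role": "user",  # in the real output this is the message's Role value
        "content": [
            {"type": "text", "text": "What does this patch show?"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
            },
        ],
    },
]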
eva/vision/data/datasets/classification/patch_camelyon.py
@@ -61,6 +61,13 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
     ]
     """Test resources."""
 
+    _expected_length = {
+        "train": 262144,
+        "val": 32768,
+        "test": 32768,
+    }
+    """Expected dataset length for each split."""
+
     _license: str = (
         "Creative Commons Zero v1.0 Universal (https://choosealicense.com/licenses/cc0-1.0/)"
     )
@@ -113,14 +120,9 @@ class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
 
     @override
     def validate(self) -> None:
-        expected_length = {
-            "train": 262144,
-            "val": 32768,
-            "test": 32768,
-        }
         _validators.check_dataset_integrity(
             self,
-            length=
+            length=self._expected_length.get(self._split, 0),
             n_classes=2,
             first_and_last_labels=("no_tumor", "tumor"),
         )
eva/vision/data/transforms/__init__.py
@@ -13,10 +13,11 @@ from eva.vision.data.transforms.intensity import (
     RandShiftIntensity,
     ScaleIntensityRange,
 )
-from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Resize, Spacing
 from eva.vision.data.transforms.utility import EnsureChannelFirst
 
 __all__ = [
+    "Resize",
     "ResizeAndCrop",
     "Squeeze",
     "CropForeground",
eva/vision/data/transforms/spatial/__init__.py
@@ -1,7 +1,8 @@
 """Transforms for spatial operations."""
 
 from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.resize import Resize
 from eva.vision.data.transforms.spatial.rotate import RandRotate90
 from eva.vision.data.transforms.spatial.spacing import Spacing
 
-__all__ = ["Spacing", "RandFlip", "RandRotate90"]
+__all__ = ["Spacing", "RandFlip", "RandRotate90", "Resize"]